From 852f8567579112a9512e4540f993c19812c14f2a Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 18 Sep 2024 15:37:26 -0300 Subject: [PATCH 01/93] optimize add --- math/src/field/fields/mersenne31/field.rs | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index e4abfab0f..5bb0871c7 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -54,18 +54,9 @@ impl IsField for Mersenne31Field { /// Returns the sum of `a` and `b`. fn add(a: &u32, b: &u32) -> u32 { - // Avoids conditional https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249 - // Working with i32 means we get a flag which informs us if overflow happens - let (sum_i32, over) = (*a as i32).overflowing_add(*b as i32); - let sum_u32 = sum_i32 as u32; - let sum_corr = sum_u32.wrapping_sub(MERSENNE_31_PRIME_FIELD_ORDER); - - //assert 31 bit clear - // If self + rhs did not overflow, return it. - // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1). - let sum = if over { sum_corr } else { sum_u32 }; - debug_assert!((sum >> 31) == 0); - Self::as_representative(&sum) + // We are using that if a and b are field elements of Mersenne31, then + // a + b has at most 32 bits, so we can use the weak_reduce function to take mudulus p. + Self::weak_reduce(a + b) } /// Returns the multiplication of `a` and `b`. From 0d68798252e282ef5c6363d9cb3ce3c986504cb7 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 18 Sep 2024 18:48:31 -0300 Subject: [PATCH 02/93] save changes. Add, sub and mul checked --- math/benches/criterion_field.rs | 3 +- math/benches/fields/mersenne31.rs | 314 +++++++++++----------- math/src/field/fields/mersenne31/field.rs | 49 +++- 3 files changed, 199 insertions(+), 167 deletions(-) diff --git a/math/benches/criterion_field.rs b/math/benches/criterion_field.rs index 1c21822de..5738c9930 100644 --- a/math/benches/criterion_field.rs +++ b/math/benches/criterion_field.rs @@ -12,6 +12,7 @@ use fields::{ criterion_group!( name = field_benches; config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = starkfield_ops_benchmarks, mersenne31_ops_benchmarks, mersenne31_mont_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks + targets = mersenne31_ops_benchmarks + //targets = starkfield_ops_benchmarks, mersenne31_ops_benchmarks, mersenne31_mont_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks ); criterion_main!(field_benches); diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs index 99e3921a5..e1ea78767 100644 --- a/math/benches/fields/mersenne31.rs +++ b/math/benches/fields/mersenne31.rs @@ -24,79 +24,79 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { .collect::>(); let mut group = c.benchmark_group("Mersenne31 operations"); - for i in input.clone().into_iter() { - group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, y) in i { - black_box(black_box(x) + black_box(y)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("mul {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, y) in i { - black_box(black_box(x) * black_box(y)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("pow by 1 {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).pow(1_u64)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("square {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).square()); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("square with pow {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).pow(2_u64)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("square with mul {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x) * black_box(x)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input( - format!("pow {:?}", &i.len()), - &(i, 5u64), - |bench, (i, a)| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).pow(*a)); - } - }); - }, - ); - } + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, y) in i { + // black_box(black_box(x) + black_box(y)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("mul {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, y) in i { + // black_box(black_box(x) * black_box(y)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("pow by 1 {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x).pow(1_u64)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("square {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x).square()); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("square with pow {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x).pow(2_u64)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("square with mul {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x) * black_box(x)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input( + // format!("pow {:?}", &i.len()), + // &(i, 5u64), + // |bench, (i, a)| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x).pow(*a)); + // } + // }); + // }, + // ); + // } for i in input.clone().into_iter() { group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { @@ -108,88 +108,88 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { }); } - for i in input.clone().into_iter() { - group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).inv().unwrap()); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("div {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, y) in i { - black_box(black_box(x) / black_box(y)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, y) in i { - black_box(black_box(x) == black_box(y)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("sqrt {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).sqrt()); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("sqrt squared {:?}", &i.len()), &i, |bench, i| { - let i: Vec = i.iter().map(|(x, _)| x * x).collect(); - bench.iter(|| { - for x in &i { - black_box(black_box(x).sqrt()); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("bitand {:?}", &i.len()), &i, |bench, i| { - // Note: we should strive to have the number of limbs be generic... ideally this benchmark group itself should have a generic type that we call into from the main runner. - let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); - bench.iter(|| { - for (x, y) in &i { - black_box(black_box(*x) & black_box(*y)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("bitor {:?}", &i.len()), &i, |bench, i| { - let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); - bench.iter(|| { - for (x, y) in &i { - black_box(black_box(*x) | black_box(*y)); - } - }); - }); - } - - for i in input.clone().into_iter() { - group.bench_with_input(format!("bitxor {:?}", &i.len()), &i, |bench, i| { - let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); - bench.iter(|| { - for (x, y) in &i { - black_box(black_box(*x) ^ black_box(*y)); - } - }); - }); - } + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x).inv().unwrap()); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("div {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, y) in i { + // black_box(black_box(x) / black_box(y)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, y) in i { + // black_box(black_box(x) == black_box(y)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("sqrt {:?}", &i.len()), &i, |bench, i| { + // bench.iter(|| { + // for (x, _) in i { + // black_box(black_box(x).sqrt()); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("sqrt squared {:?}", &i.len()), &i, |bench, i| { + // let i: Vec = i.iter().map(|(x, _)| x * x).collect(); + // bench.iter(|| { + // for x in &i { + // black_box(black_box(x).sqrt()); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("bitand {:?}", &i.len()), &i, |bench, i| { + // // Note: we should strive to have the number of limbs be generic... ideally this benchmark group itself should have a generic type that we call into from the main runner. + // let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + // bench.iter(|| { + // for (x, y) in &i { + // black_box(black_box(*x) & black_box(*y)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("bitor {:?}", &i.len()), &i, |bench, i| { + // let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + // bench.iter(|| { + // for (x, y) in &i { + // black_box(black_box(*x) | black_box(*y)); + // } + // }); + // }); + // } + + // for i in input.clone().into_iter() { + // group.bench_with_input(format!("bitxor {:?}", &i.len()), &i, |bench, i| { + // let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + // bench.iter(|| { + // for (x, y) in &i { + // black_box(black_box(*x) ^ black_box(*y)); + // } + // }); + // }); + // } } diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 5bb0871c7..97311db14 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -54,6 +54,19 @@ impl IsField for Mersenne31Field { /// Returns the sum of `a` and `b`. fn add(a: &u32, b: &u32) -> u32 { + // // Avoids conditional https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249 + // // Working with i32 means we get a flag which informs us if overflow happens + // let (sum_i32, over) = (*a as i32).overflowing_add(*b as i32); + // let sum_u32 = sum_i32 as u32; + // let sum_corr = sum_u32.wrapping_sub(MERSENNE_31_PRIME_FIELD_ORDER); + + // //assert 31 bit clear + // // If self + rhs did not overflow, return it. + // // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1). + // let sum = if over { sum_corr } else { sum_u32 }; + // debug_assert!((sum >> 31) == 0); + // Self::as_representative(&sum) + // We are using that if a and b are field elements of Mersenne31, then // a + b has at most 32 bits, so we can use the weak_reduce function to take mudulus p. Self::weak_reduce(a + b) @@ -73,6 +86,8 @@ impl IsField for Mersenne31Field { // Hence we need to remove the most significant bit and subtract 1. sub -= over as u32; sub & MERSENNE_31_PRIME_FIELD_ORDER + + // Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b) } /// Returns the additive inverse of `a`. @@ -122,16 +137,18 @@ impl IsField for Mersenne31Field { /// Returns the element `x * 1` where 1 is the multiplicative neutral element. fn from_u64(x: u64) -> u32 { - let (lo, hi) = (x as u32 as u64, x >> 32); - // 2^32 = 2 (mod Mersenne 31 bit prime) - // t <= (2^32 - 1) + 2 * (2^32 - 1) = 3 * 2^32 - 3 = 6 * 2^31 - 3 - let t = lo + 2 * hi; + // let (lo, hi) = (x as u32 as u64, x >> 32); + // // 2^32 = 2 (mod Mersenne 31 bit prime) + // // t <= (2^32 - 1) + 2 * (2^32 - 1) = 3 * 2^32 - 3 = 6 * 2^31 - 3 + // let t = lo + 2 * hi; + + // const MASK: u64 = (1 << 31) - 1; + // let (lo, hi) = ((t & MASK) as u32, (t >> 31) as u32); + // // 2^31 = 1 mod Mersenne31 + // // lo < 2^31, hi < 6, so lo + hi < 2^32. + // Self::weak_reduce(lo + hi) - const MASK: u64 = (1 << 31) - 1; - let (lo, hi) = ((t & MASK) as u32, (t >> 31) as u32); - // 2^31 = 1 mod Mersenne31 - // lo < 2^31, hi < 6, so lo + hi < 2^32. - Self::weak_reduce(lo + hi) + (((((x >> 31) + x + 1) >> 31) + x) & (MERSENNE_31_PRIME_FIELD_ORDER as u64)) as u32 } /// Takes as input an element of BaseType and returns the internal representation @@ -196,12 +213,18 @@ impl Display for FieldElement { mod tests { use super::*; type F = Mersenne31Field; + type FE = FieldElement; #[test] fn from_hex_for_b_is_11() { assert_eq!(F::from_hex("B").unwrap(), 11); } + #[test] + fn from_hex_for_b_is_11_v2() { + assert_eq!(FE::from_hex("B").unwrap(), FE::from(11)); + } + #[test] fn sum_delayed_reduction() { let up_to = u32::pow(2, 16); @@ -257,6 +280,14 @@ mod tests { assert_eq!(c, F::zero()); } + #[test] + fn max_order_plus_1_is_0_v2() { + assert_eq!( + FE::from((MERSENNE_31_PRIME_FIELD_ORDER - 1) as u64) + FE::from(1), + FE::from(0) + ); + } + #[test] fn comparing_13_and_13_are_equal() { let a = F::from_base_type(13); From 5ebc30adec6f44c9433b730b87c5f0966e164ee5 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 19 Sep 2024 12:50:15 -0300 Subject: [PATCH 03/93] fix tests --- math/src/field/fields/mersenne31/field.rs | 162 ++++++---------------- 1 file changed, 44 insertions(+), 118 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 97311db14..3c4f77784 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -54,19 +54,6 @@ impl IsField for Mersenne31Field { /// Returns the sum of `a` and `b`. fn add(a: &u32, b: &u32) -> u32 { - // // Avoids conditional https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249 - // // Working with i32 means we get a flag which informs us if overflow happens - // let (sum_i32, over) = (*a as i32).overflowing_add(*b as i32); - // let sum_u32 = sum_i32 as u32; - // let sum_corr = sum_u32.wrapping_sub(MERSENNE_31_PRIME_FIELD_ORDER); - - // //assert 31 bit clear - // // If self + rhs did not overflow, return it. - // // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1). - // let sum = if over { sum_corr } else { sum_u32 }; - // debug_assert!((sum >> 31) == 0); - // Self::as_representative(&sum) - // We are using that if a and b are field elements of Mersenne31, then // a + b has at most 32 bits, so we can use the weak_reduce function to take mudulus p. Self::weak_reduce(a + b) @@ -79,15 +66,7 @@ impl IsField for Mersenne31Field { } fn sub(a: &u32, b: &u32) -> u32 { - let (mut sub, over) = a.overflowing_sub(*b); - - // If we didn't overflow we have the correct value. - // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1. - // Hence we need to remove the most significant bit and subtract 1. - sub -= over as u32; - sub & MERSENNE_31_PRIME_FIELD_ORDER - - // Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b) + Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b) } /// Returns the additive inverse of `a`. @@ -126,7 +105,7 @@ impl IsField for Mersenne31Field { } /// Returns the additive neutral element. - fn zero() -> Self::BaseType { + fn zero() -> u32 { 0u32 } @@ -137,17 +116,6 @@ impl IsField for Mersenne31Field { /// Returns the element `x * 1` where 1 is the multiplicative neutral element. fn from_u64(x: u64) -> u32 { - // let (lo, hi) = (x as u32 as u64, x >> 32); - // // 2^32 = 2 (mod Mersenne 31 bit prime) - // // t <= (2^32 - 1) + 2 * (2^32 - 1) = 3 * 2^32 - 3 = 6 * 2^31 - 3 - // let t = lo + 2 * hi; - - // const MASK: u64 = (1 << 31) - 1; - // let (lo, hi) = ((t & MASK) as u32, (t >> 31) as u32); - // // 2^31 = 1 mod Mersenne31 - // // lo < 2^31, hi < 6, so lo + hi < 2^32. - // Self::weak_reduce(lo + hi) - (((((x >> 31) + x + 1) >> 31) + x) & (MERSENNE_31_PRIME_FIELD_ORDER as u64)) as u32 } @@ -213,7 +181,7 @@ impl Display for FieldElement { mod tests { use super::*; type F = Mersenne31Field; - type FE = FieldElement; + type FE = FieldElement; #[test] fn from_hex_for_b_is_11() { @@ -250,198 +218,156 @@ mod tests { #[test] fn one_plus_1_is_2() { - let a = F::one(); - let b = F::one(); - let c = F::add(&a, &b); - assert_eq!(c, 2u32); + assert_eq!(FE::one() + FE::one(), FE::from(&2u32)); } #[test] fn neg_1_plus_1_is_0() { - let a = F::neg(&F::one()); - let b = F::one(); - let c = F::add(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!(-FE::one() + FE::one(), FE::zero()); } #[test] fn neg_1_plus_2_is_1() { - let a = F::neg(&F::one()); - let b = F::from_base_type(2u32); - let c = F::add(&a, &b); - assert_eq!(c, F::one()); + assert_eq!(-FE::one() + FE::from(&2u32), FE::one()); } #[test] fn max_order_plus_1_is_0() { - let a = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); - let b = F::one(); - let c = F::add(&a, &b); - assert_eq!(c, F::zero()); - } - - #[test] - fn max_order_plus_1_is_0_v2() { assert_eq!( - FE::from((MERSENNE_31_PRIME_FIELD_ORDER - 1) as u64) + FE::from(1), + FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)) + FE::from(1), FE::from(0) ); } #[test] fn comparing_13_and_13_are_equal() { - let a = F::from_base_type(13); - let b = F::from_base_type(13); - assert_eq!(a, b); + assert_eq!(FE::from(&13u32), FE::from(13)); } #[test] fn comparing_13_and_8_they_are_not_equal() { - let a = F::from_base_type(13); - let b = F::from_base_type(8); - assert_ne!(a, b); + assert_ne!(FE::from(&13u32), FE::from(8)); } #[test] fn one_sub_1_is_0() { - let a = F::one(); - let b = F::one(); - let c = F::sub(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!(FE::one() - FE::one(), FE::zero()); } #[test] fn zero_sub_1_is_order_minus_1() { - let a = F::zero(); - let b = F::one(); - let c = F::sub(&a, &b); - assert_eq!(c, MERSENNE_31_PRIME_FIELD_ORDER - 1); + assert_eq!( + FE::zero() - FE::one(), + FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)) + ); } #[test] fn neg_1_sub_neg_1_is_0() { - let a = F::neg(&F::one()); - let b = F::neg(&F::one()); - let c = F::sub(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!(-FE::one() - (-FE::one()), FE::zero()); } #[test] - fn neg_1_sub_1_is_neg_1() { - let a = F::neg(&F::one()); - let b = F::zero(); - let c = F::sub(&a, &b); - assert_eq!(c, F::neg(&F::one())); + fn neg_1_sub_0_is_neg_1() { + assert_eq!(-FE::one() - FE::zero(), -FE::one()); } #[test] fn mul_neutral_element() { - let a = F::from_base_type(1); - let b = F::from_base_type(2); - let c = F::mul(&a, &b); - assert_eq!(c, F::from_base_type(2)); + assert_eq!(FE::one() * FE::from(&2u32), FE::from(&2u32)); } #[test] fn mul_2_3_is_6() { - let a = F::from_base_type(2); - let b = F::from_base_type(3); - assert_eq!(a * b, F::from_base_type(6)); + assert_eq!(FE::from(&2u32) * FE::from(&3u32), FE::from(&6u32)); } #[test] fn mul_order_neg_1() { - let a = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); - let b = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); - let c = F::mul(&a, &b); - assert_eq!(c, F::from_base_type(1)); + assert_eq!( + FE::from(MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1) + * FE::from(MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1), + FE::one() + ); } #[test] fn pow_p_neg_1() { assert_eq!( - F::pow(&F::from_base_type(2), MERSENNE_31_PRIME_FIELD_ORDER - 1), - F::one() + FE::pow(&FE::from(&2u32), MERSENNE_31_PRIME_FIELD_ORDER - 1), + FE::one() ) } #[test] fn inv_0_error() { - let result = F::inv(&F::zero()); + let result = FE::inv(&FE::zero()); assert!(matches!(result, Err(FieldError::InvZeroError))); } #[test] fn inv_2() { - let result = F::inv(&F::from_base_type(2u32)).unwrap(); + let result = FE::inv(&FE::from(&2u32)).unwrap(); // sage: 1 / F(2) = 1073741824 - assert_eq!(result, 1073741824); + assert_eq!(result, FE::from(1073741824)); } #[test] fn pow_2_3() { - assert_eq!(F::pow(&F::from_base_type(2), 3_u64), 8) + assert_eq!(FE::pow(&FE::from(&2u32), 3u64), FE::from(8)); } #[test] fn div_1() { - assert_eq!(F::div(&F::from_base_type(2), &F::from_base_type(1)), 2) + assert_eq!(FE::from(&2u32) / FE::from(&1u32), FE::from(&2u32)); } #[test] fn div_4_2() { - assert_eq!(F::div(&F::from_base_type(4), &F::from_base_type(2)), 2) + assert_eq!(FE::from(&4u32) / FE::from(&2u32), FE::from(&2u32)); } - // 1431655766 #[test] fn div_4_3() { // sage: F(4) / F(3) = 1431655766 - assert_eq!( - F::div(&F::from_base_type(4), &F::from_base_type(3)), - 1431655766 - ) + assert_eq!(FE::from(&4u32) / FE::from(&3u32), FE::from(1431655766)); } #[test] fn two_plus_its_additive_inv_is_0() { - let two = F::from_base_type(2); - - assert_eq!(F::add(&two, &F::neg(&two)), F::zero()) + assert_eq!(FE::from(&2u32) + (-FE::from(&2u32)), FE::zero()); } #[test] fn from_u64_test() { - let num = F::from_u64(1u64); - assert_eq!(num, F::one()); + assert_eq!(FE::from(1u64), FE::one()); } #[test] fn creating_a_field_element_from_its_representative_returns_the_same_element_1() { - let change = 1; - let f1 = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER + change); - let f2 = F::from_base_type(Mersenne31Field::representative(&f1)); + let change: u32 = MERSENNE_31_PRIME_FIELD_ORDER + 1; + let f1 = FE::from(&change); + let f2 = FE::from(&FE::representative(&f1)); assert_eq!(f1, f2); } #[test] fn creating_a_field_element_from_its_representative_returns_the_same_element_2() { - let change = 8; - let f1 = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER + change); - let f2 = F::from_base_type(Mersenne31Field::representative(&f1)); + let change: u32 = MERSENNE_31_PRIME_FIELD_ORDER + 8; + let f1 = FE::from(&change); + let f2 = FE::from(&FE::representative(&f1)); assert_eq!(f1, f2); } #[test] fn from_base_type_test() { - let b = F::from_base_type(1u32); - assert_eq!(b, F::one()); + assert_eq!(FE::from(&1u32), FE::one()); } #[cfg(feature = "std")] #[test] fn to_hex_test() { - let num = F::from_hex("B").unwrap(); - assert_eq!(F::to_hex(&num), "B"); + let num = FE::from_hex("B").unwrap(); + assert_eq!(FE::to_hex(&num), "B"); } } From 81439ae76ec2dad0368897b9d3dece702a0a35be Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 19 Sep 2024 16:14:44 -0300 Subject: [PATCH 04/93] add new inv --- math/src/field/fields/mersenne31/field.rs | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 3c4f77784..b15adc168 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -42,6 +42,33 @@ impl Mersenne31Field { // Delayed reduction Self::from_u64(iter.map(|x| (x as u64)).sum::()) } + + pub fn new_inv(x: &u32) -> u32 { + let mut a: u32 = 1; + let mut b: u32 = 0; + let mut y: u32 = x.clone(); + let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; + let q: u32 = 31; + let mut e: u32; + let mut temp: u64; + let mut temp2: u32; + + loop { + e = y.trailing_zeros(); + y = y / (2u64.pow(e) as u32); + temp = 2u64.pow(q.wrapping_sub(e)); + a = Self::from_u64(temp * a as u64); + if y == 1 { + return a; + }; + temp2 = a.wrapping_add(b); + b = a; + a = temp2; + temp2 = y.wrapping_add(z); + z = y; + y = temp2; + } + } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -370,4 +397,18 @@ mod tests { let num = FE::from_hex("B").unwrap(); assert_eq!(FE::to_hex(&num), "B"); } + + #[test] + fn new_inverse_test() { + let x = FE::from(&823451u32); + + let x_inv_original = FE::inv(&x).unwrap(); + let x_inv_new = FE::from(&F::new_inv(&823451u32)); + + println!("Original: {:?}", x_inv_original); + + println!("Nueva: {:?}", x_inv_new); + + assert_eq!(x_inv_original, x_inv_new); + } } From 47cf7daca9dbb3210bd56a4e83a7be54857176ee Mon Sep 17 00:00:00 2001 From: Nicole Date: Thu, 19 Sep 2024 16:45:07 -0300 Subject: [PATCH 05/93] add mult by powers of two --- math/src/field/fields/mersenne31/field.rs | 103 ++++++++++++++++------ 1 file changed, 76 insertions(+), 27 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 97311db14..4dbe3d29a 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -54,6 +54,7 @@ impl IsField for Mersenne31Field { /// Returns the sum of `a` and `b`. fn add(a: &u32, b: &u32) -> u32 { + // // OLD VERSION: // // Avoids conditional https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249 // // Working with i32 means we get a flag which informs us if overflow happens // let (sum_i32, over) = (*a as i32).overflowing_add(*b as i32); @@ -67,6 +68,7 @@ impl IsField for Mersenne31Field { // debug_assert!((sum >> 31) == 0); // Self::as_representative(&sum) + //NEW VERSION: // We are using that if a and b are field elements of Mersenne31, then // a + b has at most 32 bits, so we can use the weak_reduce function to take mudulus p. Self::weak_reduce(a + b) @@ -79,15 +81,16 @@ impl IsField for Mersenne31Field { } fn sub(a: &u32, b: &u32) -> u32 { - let (mut sub, over) = a.overflowing_sub(*b); + // // OLD VERSION: + // let (mut sub, over) = a.overflowing_sub(*b); + // // If we didn't overflow we have the correct value. + // // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1. + // // Hence we need to remove the most significant bit and subtract 1. + // sub -= over as u32; + // sub & MERSENNE_31_PRIME_FIELD_ORDER - // If we didn't overflow we have the correct value. - // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1. - // Hence we need to remove the most significant bit and subtract 1. - sub -= over as u32; - sub & MERSENNE_31_PRIME_FIELD_ORDER - - // Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b) + // NEW VERSION: + Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b) } /// Returns the additive inverse of `a`. @@ -101,17 +104,27 @@ impl IsField for Mersenne31Field { if *a == Self::zero() || *a == MERSENNE_31_PRIME_FIELD_ORDER { return Err(FieldError::InvZeroError); } - let p101 = Self::mul(&Self::pow(a, 4u32), a); - let p1111 = Self::mul(&Self::square(&p101), &p101); - let p11111111 = Self::mul(&Self::pow(&p1111, 16u32), &p1111); - let p111111110000 = Self::pow(&p11111111, 16u32); - let p111111111111 = Self::mul(&p111111110000, &p1111); - let p1111111111111111 = Self::mul(&Self::pow(&p111111110000, 16u32), &p11111111); - let p1111111111111111111111111111 = - Self::mul(&Self::pow(&p1111111111111111, 4096u32), &p111111111111); - let p1111111111111111111111111111101 = - Self::mul(&Self::pow(&p1111111111111111111111111111, 8u32), &p101); - Ok(p1111111111111111111111111111101) + // // OLD VERSION: + // let p101 = Self::mul(&Self::pow(a, 4u32), a); + // let p1111 = Self::mul(&Self::square(&p101), &p101); + // let p11111111 = Self::mul(&Self::pow(&p1111, 16u32), &p1111); + // let p111111110000 = Self::pow(&p11111111, 16u32); + // let p111111111111 = Self::mul(&p111111110000, &p1111); + // let p1111111111111111 = Self::mul(&Self::pow(&p111111110000, 16u32), &p11111111); + // let p1111111111111111111111111111 = + // Self::mul(&Self::pow(&p1111111111111111, 4096u32), &p111111111111); + // let p1111111111111111111111111111101 = + // Self::mul(&Self::pow(&p1111111111111111111111111111, 8u32), &p101); + // Ok(p1111111111111111111111111111101) + + // NEW VERSION: + let t0 = sqn(*a, 2) * a; + let t1 = t0 * t0 * t0; + let t2 = sqn(t1, 3) * t0; + let t3 = t2 * t2 * t0; + let t4 = sqn(t3, 8) * t3; + let t5 = sqn(t4, 8) * t3; + Ok(sqn(t5, 7) * t2) } /// Returns the division of `a` and `b`. @@ -137,6 +150,7 @@ impl IsField for Mersenne31Field { /// Returns the element `x * 1` where 1 is the multiplicative neutral element. fn from_u64(x: u64) -> u32 { + // // OLD VERSION: // let (lo, hi) = (x as u32 as u64, x >> 32); // // 2^32 = 2 (mod Mersenne 31 bit prime) // // t <= (2^32 - 1) + 2 * (2^32 - 1) = 3 * 2^32 - 3 = 6 * 2^31 - 3 @@ -148,6 +162,7 @@ impl IsField for Mersenne31Field { // // lo < 2^31, hi < 6, so lo + hi < 2^32. // Self::weak_reduce(lo + hi) + // NEW VERSION: (((((x >> 31) + x + 1) >> 31) + x) & (MERSENNE_31_PRIME_FIELD_ORDER as u64)) as u32 } @@ -158,6 +173,22 @@ impl IsField for Mersenne31Field { } } +/// Computes `a^(2*n)`. +pub fn sqn(mut a: u32, n: usize) -> u32 { + for _ in 0..n { + a = Mersenne31Field::mul(&a, &a); + } + a +} + +/// Computes a * 2^k, with |k| < 31 +pub fn mul_power_two(a: u32, k: i32) -> u64 { + // If a uses 32 bits, then a * 2^k uses 32 + k bits. + let msb = (a & (u32::MAX << 32 - k)) >> (32 - k - 1); // The k+1 msb. + let lsb = (a & (u32::MAX >> k)) << k; // The 31-k lsb with k zeros. + lsb as u64 + msb as u64 +} + impl IsPrimeField for Mersenne31Field { type RepresentativeType = u32; @@ -215,6 +246,24 @@ mod tests { type F = Mersenne31Field; type FE = FieldElement; + #[test] + fn mul_power_two_is_correct() { + let a = 3u32; + let k = 2; + let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); + let result = mul_power_two(a, k); + assert_eq!(FE::from(result), expected_result) + } + + #[test] + fn mul_power_two_is_correct_2() { + let a = 229287u32; + let k = 4; + let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); + let result = mul_power_two(a, k); + assert_eq!(FE::from(result), expected_result) + } + #[test] fn from_hex_for_b_is_11() { assert_eq!(F::from_hex("B").unwrap(), 11); @@ -283,7 +332,7 @@ mod tests { #[test] fn max_order_plus_1_is_0_v2() { assert_eq!( - FE::from((MERSENNE_31_PRIME_FIELD_ORDER - 1) as u64) + FE::from(1), + FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)) + FE::from(1), FE::from(0) ); } @@ -304,10 +353,10 @@ mod tests { #[test] fn one_sub_1_is_0() { - let a = F::one(); - let b = F::one(); - let c = F::sub(&a, &b); - assert_eq!(c, F::zero()); + let a = FE::one(); + let b = FE::one(); + let c = a - b; + assert_eq!(c, FE::zero()); } #[test] @@ -367,15 +416,15 @@ mod tests { #[test] fn inv_0_error() { - let result = F::inv(&F::zero()); + let result = FE::inv(&FE::zero()); assert!(matches!(result, Err(FieldError::InvZeroError))); } #[test] fn inv_2() { - let result = F::inv(&F::from_base_type(2u32)).unwrap(); + let result = FE::inv(&FE::from(2)).unwrap(); // sage: 1 / F(2) = 1073741824 - assert_eq!(result, 1073741824); + assert_eq!(result, FE::from(1073741824)); } #[test] From cfba8bc7d95016e8c1515bb057233a323f4d7937 Mon Sep 17 00:00:00 2001 From: Nicole Date: Thu, 19 Sep 2024 17:11:02 -0300 Subject: [PATCH 06/93] replace inverse --- math/src/field/fields/mersenne31/field.rs | 96 +++++++++++------------ 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index bb7e422d4..c3fd76a8d 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -42,33 +42,6 @@ impl Mersenne31Field { // Delayed reduction Self::from_u64(iter.map(|x| (x as u64)).sum::()) } - - pub fn new_inv(x: &u32) -> u32 { - let mut a: u32 = 1; - let mut b: u32 = 0; - let mut y: u32 = x.clone(); - let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; - let q: u32 = 31; - let mut e: u32; - let mut temp: u64; - let mut temp2: u32; - - loop { - e = y.trailing_zeros(); - y = y / (2u64.pow(e) as u32); - temp = 2u64.pow(q.wrapping_sub(e)); - a = Self::from_u64(temp * a as u64); - if y == 1 { - return a; - }; - temp2 = a.wrapping_add(b); - b = a; - a = temp2; - temp2 = y.wrapping_add(z); - z = y; - y = temp2; - } - } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -103,8 +76,8 @@ impl IsField for Mersenne31Field { } /// Returns the multiplicative inverse of `a`. - fn inv(a: &u32) -> Result { - if *a == Self::zero() || *a == MERSENNE_31_PRIME_FIELD_ORDER { + fn inv(x: &u32) -> Result { + if *x == Self::zero() || *x == MERSENNE_31_PRIME_FIELD_ORDER { return Err(FieldError::InvZeroError); } // // OLD VERSION: @@ -120,14 +93,39 @@ impl IsField for Mersenne31Field { // Self::mul(&Self::pow(&p1111111111111111111111111111, 8u32), &p101); // Ok(p1111111111111111111111111111101) + // // OLD VERSION: + // let t0 = sqn(*x, 2) * x; + // let t1 = t0 * t0 * t0; + // let t2 = sqn(t1, 3) * t0; + // let t3 = t2 * t2 * t0; + // let t4 = sqn(t3, 8) * t3; + // let t5 = sqn(t4, 8) * t3; + // Ok(sqn(t5, 7) * t2) + // NEW VERSION: - let t0 = sqn(*a, 2) * a; - let t1 = t0 * t0 * t0; - let t2 = sqn(t1, 3) * t0; - let t3 = t2 * t2 * t0; - let t4 = sqn(t3, 8) * t3; - let t5 = sqn(t4, 8) * t3; - Ok(sqn(t5, 7) * t2) + let mut a: u32 = 1; + let mut b: u32 = 0; + let mut y: u32 = x.clone(); + let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; + let q: u32 = 31; + let mut e: u32; + let mut temp: u64; + let mut temp2: u32; + + loop { + e = y.trailing_zeros(); + y = y / (2u64.pow(e) as u32); + a = mul_power_two(a, q.wrapping_sub(e)); + if y == 1 { + return Ok(a); + }; + temp2 = a.wrapping_add(b); + b = a; + a = temp2; + temp2 = y.wrapping_add(z); + z = y; + y = temp2; + } } /// Returns the division of `a` and `b`. @@ -172,11 +170,11 @@ pub fn sqn(mut a: u32, n: usize) -> u32 { } /// Computes a * 2^k, with |k| < 31 -pub fn mul_power_two(a: u32, k: i32) -> u64 { +pub fn mul_power_two(a: u32, k: u32) -> u32 { // If a uses 32 bits, then a * 2^k uses 32 + k bits. let msb = (a & (u32::MAX << 32 - k)) >> (32 - k - 1); // The k+1 msb. let lsb = (a & (u32::MAX >> k)) << k; // The 31-k lsb with k zeros. - lsb as u64 + msb as u64 + lsb + msb } impl IsPrimeField for Mersenne31Field { @@ -242,7 +240,7 @@ mod tests { let k = 2; let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); let result = mul_power_two(a, k); - assert_eq!(FE::from(result), expected_result) + assert_eq!(FE::from(&result), expected_result) } #[test] @@ -251,7 +249,7 @@ mod tests { let k = 4; let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); let result = mul_power_two(a, k); - assert_eq!(FE::from(result), expected_result) + assert_eq!(FE::from(&result), expected_result) } #[test] @@ -442,17 +440,17 @@ mod tests { assert_eq!(FE::to_hex(&num), "B"); } - #[test] - fn new_inverse_test() { - let x = FE::from(&823451u32); + // #[test] + // fn new_inverse_test() { + // let x = FE::from(&823451u32); - let x_inv_original = FE::inv(&x).unwrap(); - let x_inv_new = FE::from(&F::new_inv(&823451u32)); + // let x_inv_original = FE::inv(&x).unwrap(); + // let x_inv_new = &FE::from(&823451u32).inv().unwrap(); - println!("Original: {:?}", x_inv_original); + // println!("Original: {:?}", x_inv_original); - println!("Nueva: {:?}", x_inv_new); + // println!("Nueva: {:?}", x_inv_new); - assert_eq!(x_inv_original, x_inv_new); - } + // assert_eq!(x_inv_original, x_inv_new); + // } } From 01eba0d0c9bc144c3e5a699895c72649e9edbe19 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 20 Sep 2024 10:21:13 -0300 Subject: [PATCH 07/93] test new inv --- math/benches/fields/mersenne31.rs | 44 +++++++++++------------ math/src/field/fields/mersenne31/field.rs | 25 +++++++------ 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs index e1ea78767..bc06f6fde 100644 --- a/math/benches/fields/mersenne31.rs +++ b/math/benches/fields/mersenne31.rs @@ -98,36 +98,36 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { // ); // } - for i in input.clone().into_iter() { - group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, y) in i { - black_box(black_box(x) - black_box(y)); - } - }); - }); - } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x).inv().unwrap()); - // } - // }); - // }); - // } - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("div {:?}", &i.len()), &i, |bench, i| { + // group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { // bench.iter(|| { // for (x, y) in i { - // black_box(black_box(x) / black_box(y)); + // black_box(black_box(x) - black_box(y)); // } // }); // }); // } + for i in input.clone().into_iter() { + group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("div {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) / black_box(y)); + } + }); + }); + } + // for i in input.clone().into_iter() { // group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { // bench.iter(|| { diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index c3fd76a8d..ccabd23aa 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -42,6 +42,14 @@ impl Mersenne31Field { // Delayed reduction Self::from_u64(iter.map(|x| (x as u64)).sum::()) } + + /// Computes a * 2^k, with |k| < 31 + pub fn mul_power_two(a: u32, k: u32) -> u32 { + // If a uses 32 bits, then a * 2^k uses 32 + k bits. + let msb = (a & (u32::MAX << 32 - k)) >> (32 - k - 1); // The k+1 msb. + let lsb = (a & (u32::MAX >> k)) << k; // The 31-k lsb with k zeros. + lsb + msb + } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -109,13 +117,12 @@ impl IsField for Mersenne31Field { let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; let q: u32 = 31; let mut e: u32; - let mut temp: u64; let mut temp2: u32; loop { e = y.trailing_zeros(); - y = y / (2u64.pow(e) as u32); - a = mul_power_two(a, q.wrapping_sub(e)); + y >>= e; + a = Self::mul_power_two(a, q.wrapping_sub(e)); if y == 1 { return Ok(a); }; @@ -169,14 +176,6 @@ pub fn sqn(mut a: u32, n: usize) -> u32 { a } -/// Computes a * 2^k, with |k| < 31 -pub fn mul_power_two(a: u32, k: u32) -> u32 { - // If a uses 32 bits, then a * 2^k uses 32 + k bits. - let msb = (a & (u32::MAX << 32 - k)) >> (32 - k - 1); // The k+1 msb. - let lsb = (a & (u32::MAX >> k)) << k; // The 31-k lsb with k zeros. - lsb + msb -} - impl IsPrimeField for Mersenne31Field { type RepresentativeType = u32; @@ -239,7 +238,7 @@ mod tests { let a = 3u32; let k = 2; let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); - let result = mul_power_two(a, k); + let result = F::mul_power_two(a, k); assert_eq!(FE::from(&result), expected_result) } @@ -248,7 +247,7 @@ mod tests { let a = 229287u32; let k = 4; let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); - let result = mul_power_two(a, k); + let result = F::mul_power_two(a, k); assert_eq!(FE::from(&result), expected_result) } From 60fd9819295de3c71269cf797b614655af801057 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 20 Sep 2024 11:33:28 -0300 Subject: [PATCH 08/93] modify old algorithm for inv --- math/src/field/fields/mersenne31/field.rs | 83 +++++++++++++---------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index ccabd23aa..935cf5a91 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -50,6 +50,12 @@ impl Mersenne31Field { let lsb = (a & (u32::MAX >> k)) << k; // The 31-k lsb with k zeros. lsb + msb } + + pub fn pow_2(a: &u32, order: u32) -> u32 { + let mut res = a.clone(); + (0..order).for_each(|_| res = Self::square(&res)); + res + } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -88,18 +94,18 @@ impl IsField for Mersenne31Field { if *x == Self::zero() || *x == MERSENNE_31_PRIME_FIELD_ORDER { return Err(FieldError::InvZeroError); } - // // OLD VERSION: - // let p101 = Self::mul(&Self::pow(a, 4u32), a); - // let p1111 = Self::mul(&Self::square(&p101), &p101); - // let p11111111 = Self::mul(&Self::pow(&p1111, 16u32), &p1111); - // let p111111110000 = Self::pow(&p11111111, 16u32); - // let p111111111111 = Self::mul(&p111111110000, &p1111); - // let p1111111111111111 = Self::mul(&Self::pow(&p111111110000, 16u32), &p11111111); - // let p1111111111111111111111111111 = - // Self::mul(&Self::pow(&p1111111111111111, 4096u32), &p111111111111); - // let p1111111111111111111111111111101 = - // Self::mul(&Self::pow(&p1111111111111111111111111111, 8u32), &p101); - // Ok(p1111111111111111111111111111101) + // OLD VERSION: + let p101 = Self::mul(&Self::pow_2(x, 2), x); + let p1111 = Self::mul(&Self::square(&p101), &p101); + let p11111111 = Self::mul(&Self::pow_2(&p1111, 4u32), &p1111); + let p111111110000 = Self::pow_2(&p11111111, 4u32); + let p111111111111 = Self::mul(&p111111110000, &p1111); + let p1111111111111111 = Self::mul(&Self::pow_2(&p111111110000, 4u32), &p11111111); + let p1111111111111111111111111111 = + Self::mul(&Self::pow_2(&p1111111111111111, 12u32), &p111111111111); + let p1111111111111111111111111111101 = + Self::mul(&Self::pow_2(&p1111111111111111111111111111, 3u32), &p101); + Ok(p1111111111111111111111111111101) // // OLD VERSION: // let t0 = sqn(*x, 2) * x; @@ -111,28 +117,28 @@ impl IsField for Mersenne31Field { // Ok(sqn(t5, 7) * t2) // NEW VERSION: - let mut a: u32 = 1; - let mut b: u32 = 0; - let mut y: u32 = x.clone(); - let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; - let q: u32 = 31; - let mut e: u32; - let mut temp2: u32; - - loop { - e = y.trailing_zeros(); - y >>= e; - a = Self::mul_power_two(a, q.wrapping_sub(e)); - if y == 1 { - return Ok(a); - }; - temp2 = a.wrapping_add(b); - b = a; - a = temp2; - temp2 = y.wrapping_add(z); - z = y; - y = temp2; - } + // let mut a: u32 = 1; + // let mut b: u32 = 0; + // let mut y: u32 = x.clone(); + // let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; + // let q: u32 = 31; + // let mut e: u32; + // let mut temp2: u32; + + // loop { + // e = y.trailing_zeros(); + // y >>= e; + // a = Self::mul_power_two(a, q - e); + // if y == 1 { + // return Ok(a); + // }; + // temp2 = a.wrapping_add(b); + // b = a; + // a = temp2; + // temp2 = y.wrapping_add(z); + // z = y; + // y = temp2; + // } } /// Returns the division of `a` and `b`. @@ -251,6 +257,15 @@ mod tests { assert_eq!(FE::from(&result), expected_result) } + #[test] + fn pow_2_is_correct() { + let a = 3u32; + let order = 12; + let result = F::pow_2(&a, order); + let expected_result = FE::pow(&FE::from(&a), 4096u32); + assert_eq!(FE::from(&result), expected_result) + } + #[test] fn from_hex_for_b_is_11() { assert_eq!(F::from_hex("B").unwrap(), 11); From 996e2247da11a18666d5b7cad8c258493639ae5c Mon Sep 17 00:00:00 2001 From: Nicole Date: Fri, 20 Sep 2024 15:59:02 -0300 Subject: [PATCH 09/93] fix tests extension --- math/src/field/fields/mersenne31/extension.rs | 209 +++++++++++------- math/src/field/fields/mersenne31/field.rs | 21 +- 2 files changed, 140 insertions(+), 90 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 3c89a2147..431c7309b 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -5,30 +5,28 @@ use crate::field::{ cubic::{CubicExtensionField, HasCubicNonResidue}, quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, }, - traits::IsField, + traits::{IsField, IsSubFieldOf}, }; use super::field::Mersenne31Field; +type FpE = FieldElement; + //Note: The inverse calculation in mersenne31/plonky3 differs from the default quadratic extension so I implemented the complex extension. ////////////////// #[derive(Clone, Debug)] -pub struct Mersenne31Complex; +pub struct Degree2ExtensionField; -impl IsField for Mersenne31Complex { +impl IsField for Degree2ExtensionField { //Elements represents a[0] = real, a[1] = imaginary - type BaseType = [FieldElement; 2]; + type BaseType = [FpE; 2]; /// Returns the component wise addition of `a` and `b` fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { [a[0] + b[0], a[1] + b[1]] } - //NOTE: THIS uses Gauss algorithm. Bench this against plonky 3 implementation to see what is faster. /// Returns the multiplication of `a` and `b` using the following - /// equation: - /// (a0 + a1 * t) * (b0 + b1 * t) = a0 * b0 + a1 * b1 * Self::residue() + (a0 * b1 + a1 * b0) * t - /// where `t.pow(2)` equals `Q::residue()`. fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { let a0b0 = a[0] * b[0]; let a1b1 = a[1] * b[1]; @@ -40,7 +38,7 @@ impl IsField for Mersenne31Complex { let [a0, a1] = a; let v0 = a0 * a1; let c0 = (a0 + a1) * (a0 - a1); - let c1 = v0 + v0; + let c1 = v0.double(); [c0, c1] } /// Returns the component wise subtraction of `a` and `b` @@ -55,13 +53,13 @@ impl IsField for Mersenne31Complex { /// Returns the multiplicative inverse of `a` fn inv(a: &Self::BaseType) -> Result { - let inv_norm = (a[0].pow(2_u64) + a[1].pow(2_u64)).inv()?; + let inv_norm = (a[0].square() + a[1].square()).inv()?; Ok([a[0] * inv_norm, -a[1] * inv_norm]) } /// Returns the division of `a` and `b` fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - Self::mul(a, &Self::inv(b).unwrap()) + ::mul(a, &Self::inv(b).unwrap()) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -71,17 +69,17 @@ impl IsField for Mersenne31Complex { /// Returns the additive neutral element of the field extension. fn zero() -> Self::BaseType { - [FieldElement::zero(), FieldElement::zero()] + [FpE::zero(), FpE::zero()] } /// Returns the multiplicative neutral element of the field extension. fn one() -> Self::BaseType { - [FieldElement::one(), FieldElement::zero()] + [FpE::one(), FpE::zero()] } /// Returns the element `x * 1` where 1 is the multiplicative neutral element. fn from_u64(x: u64) -> Self::BaseType { - [FieldElement::from(x), FieldElement::zero()] + [FpE::from(x), FpE::zero()] } /// Takes as input an element of BaseType and returns the internal representation @@ -93,6 +91,55 @@ impl IsField for Mersenne31Complex { } } +impl IsSubFieldOf for Mersenne31Field { + fn mul( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FpE::from(a) * b[0]; + let c1 = FpE::from(a) * b[1]; + [c0, c1] + } + + fn add( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::add(a, b[0].value())); + let c1 = FieldElement::from_raw(*b[1].value()); + [c0, c1] + } + + fn div( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree2ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) + } + + fn sub( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::sub(a, b[0].value())); + let c1 = FieldElement::from_raw(::neg(b[1].value())); + [c0, c1] + } + + fn embed(a: Self::BaseType) -> ::BaseType { + [FieldElement::from_raw(a), FieldElement::zero()] + } + + #[cfg(feature = "alloc")] + fn to_subfield_vec( + b: ::BaseType, + ) -> alloc::vec::Vec { + b.into_iter().map(|x| x.to_raw()).collect() + } +} + +/* pub type Mersenne31ComplexQuadraticExtensionField = QuadraticExtensionField; @@ -115,7 +162,9 @@ impl HasQuadraticNonResidue for Mersenne31Complex { ])) } } +*/ +/* pub type Mersenne31ComplexCubicExtensionField = CubicExtensionField; @@ -137,168 +186,174 @@ impl HasCubicNonResidue for Mersenne31Complex { ])) } } +*/ #[cfg(test)] mod tests { + use core::{num::FpCategory, ops::Neg}; + use crate::field::fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER; use super::*; - type Fi = Mersenne31Complex; - type F = FieldElement; - - //NOTE: from_u64 reflects from_real - //NOTE: for imag use from_base_type + type Fp2E = FieldElement; #[test] fn add_real_one_plus_one_is_two() { - assert_eq!(Fi::add(&Fi::one(), &Fi::one()), Fi::from_u64(2)) + assert_eq!(Fp2E::one() + Fp2E::one(), Fp2E::from(2)) } #[test] fn add_real_neg_one_plus_one_is_zero() { - assert_eq!(Fi::add(&Fi::neg(&Fi::one()), &Fi::one()), Fi::zero()) + assert_eq!(Fp2E::one() + Fp2E::one().neg(), Fp2E::zero()) } #[test] fn add_real_neg_one_plus_two_is_one() { - assert_eq!(Fi::add(&Fi::neg(&Fi::one()), &Fi::from_u64(2)), Fi::one()) + assert_eq!(Fp2E::one().neg() + Fp2E::from(2), Fp2E::one()) } #[test] fn add_real_neg_one_plus_neg_one_is_order_sub_two() { assert_eq!( - Fi::add(&Fi::neg(&Fi::one()), &Fi::neg(&Fi::one())), - Fi::from_u64((MERSENNE_31_PRIME_FIELD_ORDER - 2).into()) + Fp2E::one().neg() + Fp2E::one().neg(), + Fp2E::new([FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 2)), FpE::zero()]) ) } #[test] fn add_complex_one_plus_one_two() { //Manually declare the complex part to one - let one = Fi::from_base_type([F::zero(), F::one()]); - let two = Fi::from_base_type([F::zero(), F::from(2)]); - assert_eq!(Fi::add(&one, &one), two) + let one_i = Fp2E::new([FpE::zero(), FpE::one()]); + let two_i = Fp2E::new([FpE::zero(), FpE::from(2)]); + assert_eq!(&one_i + &one_i, two_i) } #[test] fn add_complex_neg_one_plus_one_is_zero() { //Manually declare the complex part to one - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::add(&neg_one, &one), Fi::zero()) + let neg_one_i = Fp2E::new([FpE::zero(), -FpE::one()]); + let one_i = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(neg_one_i + one_i, Fp2E::zero()) } #[test] fn add_complex_neg_one_plus_two_is_one() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - let two = Fi::from_base_type([F::zero(), F::from(2)]); - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::add(&neg_one, &two), one) + let neg_one_i = Fp2E::new([FpE::zero(), -FpE::one()]); + let two_i = Fp2E::new([FpE::zero(), FpE::from(2)]); + let one_i = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(&neg_one_i + &two_i, one_i) } #[test] fn add_complex_neg_one_plus_neg_one_imag_is_order_sub_two() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); + let neg_one_i = Fp2E::new([FpE::zero(), -FpE::one()]); assert_eq!( - Fi::add(&neg_one, &neg_one)[1], - F::new(MERSENNE_31_PRIME_FIELD_ORDER - 2) + (&neg_one_i + &neg_one_i).value()[1], + FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 2)) ) } #[test] fn add_order() { - let a = Fi::from_base_type([-F::one(), F::one()]); - let b = Fi::from_base_type([F::from(2), F::new(MERSENNE_31_PRIME_FIELD_ORDER - 2)]); - let c = Fi::from_base_type([F::one(), -F::one()]); - assert_eq!(Fi::add(&a, &b), c) + let a = Fp2E::new([-FpE::one(), FpE::one()]); + let b = Fp2E::new([ + FpE::from(2), + FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 2)), + ]); + let c = Fp2E::new([FpE::one(), -FpE::one()]); + assert_eq!(&a + &b, c) } #[test] fn add_equal_zero() { - let a = Fi::from_base_type([-F::one(), -F::one()]); - let b = Fi::from_base_type([F::one(), F::one()]); - assert_eq!(Fi::add(&a, &b), Fi::zero()) + let a = Fp2E::new([-FpE::one(), -FpE::one()]); + let b = Fp2E::new([FpE::one(), FpE::one()]); + assert_eq!(&a + &b, Fp2E::zero()) } #[test] fn add_plus_one() { - let a = Fi::from_base_type([F::one(), F::from(2)]); - let b = Fi::from_base_type([F::one(), F::one()]); - let c = Fi::from_base_type([F::from(2), F::from(3)]); - assert_eq!(Fi::add(&a, &b), c) + let a = Fp2E::new([FpE::one(), FpE::from(2)]); + let b = Fp2E::new([FpE::one(), FpE::one()]); + let c = Fp2E::new([FpE::from(2), FpE::from(3)]); + assert_eq!(&a + &b, c) } #[test] fn sub_real_one_sub_one_is_zero() { - assert_eq!(Fi::sub(&Fi::one(), &Fi::one()), Fi::zero()) + assert_eq!(&Fp2E::one() - &Fp2E::one(), Fp2E::zero()) } #[test] fn sub_real_two_sub_two_is_zero() { - assert_eq!( - Fi::sub(&Fi::from_u64(2u64), &Fi::from_u64(2u64)), - Fi::zero() - ) + assert_eq!(&Fp2E::from(2) - &Fp2E::from(2), Fp2E::zero()) } #[test] fn sub_real_neg_one_sub_neg_one_is_zero() { - assert_eq!( - Fi::sub(&Fi::neg(&Fi::one()), &Fi::neg(&Fi::one())), - Fi::zero() - ) + assert_eq!(Fp2E::one().neg() - Fp2E::one().neg(), Fp2E::zero()) } #[test] fn sub_real_two_sub_one_is_one() { - assert_eq!(Fi::sub(&Fi::from_u64(2), &Fi::one()), Fi::one()) + assert_eq!(Fp2E::from(2) - Fp2E::one(), Fp2E::one()) } #[test] fn sub_real_neg_one_sub_zero_is_neg_one() { - assert_eq!( - Fi::sub(&Fi::neg(&Fi::one()), &Fi::zero()), - Fi::neg(&Fi::one()) - ) + assert_eq!(Fp2E::one().neg() - Fp2E::zero(), Fp2E::one().neg()) } #[test] fn sub_complex_one_sub_one_is_zero() { - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::sub(&one, &one), Fi::zero()) + let one = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(&one - &one, Fp2E::zero()) } #[test] fn sub_complex_two_sub_two_is_zero() { - let two = Fi::from_base_type([F::zero(), F::from(2)]); - assert_eq!(Fi::sub(&two, &two), Fi::zero()) + let two = Fp2E::new([FpE::zero(), FpE::from(2)]); + assert_eq!(&two - &two, Fp2E::zero()) } #[test] fn sub_complex_neg_one_sub_neg_one_is_zero() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - assert_eq!(Fi::sub(&neg_one, &neg_one), Fi::zero()) + let neg_one = Fp2E::new([FpE::zero(), -FpE::one()]); + assert_eq!(&neg_one - &neg_one, Fp2E::zero()) } #[test] fn sub_complex_two_sub_one_is_one() { - let two = Fi::from_base_type([F::zero(), F::from(2)]); - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::sub(&two, &one), one) + let two = Fp2E::new([FpE::zero(), FpE::from(2)]); + let one = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(&two - &one, one) } #[test] fn sub_complex_neg_one_sub_zero_is_neg_one() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - assert_eq!(Fi::sub(&neg_one, &Fi::zero()), neg_one) + let neg_one = Fp2E::new([FpE::zero(), -FpE::one()]); + assert_eq!(&neg_one - &Fp2E::zero(), neg_one) } #[test] fn mul() { - let a = Fi::from_base_type([F::from(2), F::from(2)]); - let b = Fi::from_base_type([F::from(4), F::from(5)]); - let c = Fi::from_base_type([-F::from(2), F::from(18)]); - assert_eq!(Fi::mul(&a, &b), c) + let a = Fp2E::new([FpE::from(2), FpE::from(2)]); + let b = Fp2E::new([FpE::from(4), FpE::from(5)]); + let c = Fp2E::new([-FpE::from(2), FpE::from(18)]); + assert_eq!(&a * &b, c) + } + + #[test] + fn square_equals_mul_by_itself() { + let a = Fp2E::new([FpE::from(2), FpE::from(3)]); + assert_eq!(a.square(), &a * &a) + } + + #[test] + fn mul_base_field_with_degree_2_extension() { + let a = FpE::from(3); + let b = Fp2E::new([FpE::from(2), FpE::from(4)]); + assert_eq!(a * b, Fp2E::new([FpE::from(6), FpE::from(12)])) } } diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index ccabd23aa..86bdbdde5 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -166,6 +166,9 @@ impl IsField for Mersenne31Field { fn from_base_type(x: u32) -> u32 { Self::weak_reduce(x) } + fn double(a: &u32) -> u32 { + Self::weak_reduce(a << 1) + } } /// Computes `a^(2*n)`. @@ -439,17 +442,9 @@ mod tests { assert_eq!(FE::to_hex(&num), "B"); } - // #[test] - // fn new_inverse_test() { - // let x = FE::from(&823451u32); - - // let x_inv_original = FE::inv(&x).unwrap(); - // let x_inv_new = &FE::from(&823451u32).inv().unwrap(); - - // println!("Original: {:?}", x_inv_original); - - // println!("Nueva: {:?}", x_inv_new); - - // assert_eq!(x_inv_original, x_inv_new); - // } + #[test] + fn double_equals_add_itself() { + let a = FE::from(1234); + assert_eq!(a + a, a.double()) + } } From 28d8b0e63a6b86a37c5dfefc387a2054fb4f12d2 Mon Sep 17 00:00:00 2001 From: Nicole Date: Fri, 20 Sep 2024 17:21:22 -0300 Subject: [PATCH 10/93] add mul for degree 4 extension --- math/src/field/fields/mersenne31/extension.rs | 208 +++++++++++++++++- 1 file changed, 200 insertions(+), 8 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 431c7309b..dc3748eb5 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -1,11 +1,14 @@ -use crate::field::{ - element::FieldElement, - errors::FieldError, - extensions::{ - cubic::{CubicExtensionField, HasCubicNonResidue}, - quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, +use crate::{ + elliptic_curve::short_weierstrass::curves::bls12_381::field_extension::LevelTwoResidue, + field::{ + element::FieldElement, + errors::FieldError, + extensions::{ + cubic::{CubicExtensionField, HasCubicNonResidue}, + quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, + }, + traits::{IsField, IsSubFieldOf}, }, - traits::{IsField, IsSubFieldOf}, }; use super::field::Mersenne31Field; @@ -139,6 +142,153 @@ impl IsSubFieldOf for Mersenne31Field { } } +type Fp2E = FieldElement; + +/// Extension of degree 4 defined with lambdaworks quadratic extension to test the correctness of Degree4ExtensionField +#[derive(Debug, Clone)] +pub struct Mersenne31LevelTwoResidue; +impl HasQuadraticNonResidue for Mersenne31LevelTwoResidue { + fn residue() -> Fp2E { + Fp2E::new([FpE::from(2), FpE::one()]) + } +} +pub type Degree4ExtensionFieldV2 = + QuadraticExtensionField; +#[derive(Clone, Debug)] +pub struct Degree4ExtensionField; + +impl IsField for Degree4ExtensionField { + //Elements represents a[0] = real, a[1] = imaginary + type BaseType = [Fp2E; 2]; + + /// Returns the component wise addition of `a` and `b` + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] + &b[0], &a[1] + &b[1]] + } + + /// Returns the multiplication of `a` and `b` using the following + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + // a = a0 + a1 * u, b = b0 + b1 * u, where + // a0 = a00 + a01 * i, a1 = a11 + a11 * i, etc + let [a00, a01] = a[0].value(); + let [a10, a11] = a[1].value(); + let [b00, b01] = b[0].value(); + let [b10, b11] = b[1].value(); + + let c00 = a00 * b00 - a01 * b01 - a11 * b11 + (a10 * b10).double() - a10 * b11 - b10 * a11; + let c01 = a00 * b01 + a01 * b00 + a10 * b10 - (a10 * b11).double() + (b10 * a11).double(); + let c10 = a00 * b10 - a01 * b11 + a10 * b00 - b01 * a11; + let c11 = a00 * b11 + a01 * b10 + a10 * b01 + a11 * b00; + + [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] + } + + fn square(a: &Self::BaseType) -> Self::BaseType { + let [a0, a1] = a; + let v0 = a0 * a1; + let c0 = (a0 + a1) * (a0 - a1); + let c1 = v0.double(); + [c0, c1] + } + /// Returns the component wise subtraction of `a` and `b` + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] - &b[0], &a[1] - &b[1]] + } + + /// Returns the component wise negation of `a` + fn neg(a: &Self::BaseType) -> Self::BaseType { + [-&a[0], -&a[1]] + } + + /// Returns the multiplicative inverse of `a` + fn inv(a: &Self::BaseType) -> Result { + let inv_norm = (a[0].square() + a[1].square()).inv()?; + Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) + } + + /// Returns the division of `a` and `b` + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + ::mul(a, &Self::inv(b).unwrap()) + } + + /// Returns a boolean indicating whether `a` and `b` are equal component wise. + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + a[0] == b[0] && a[1] == b[1] + } + + /// Returns the additive neutral element of the field extension. + fn zero() -> Self::BaseType { + [Fp2E::zero(), Fp2E::zero()] + } + + /// Returns the multiplicative neutral element of the field extension. + fn one() -> Self::BaseType { + [Fp2E::one(), Fp2E::zero()] + } + + /// Returns the element `x * 1` where 1 is the multiplicative neutral element. + fn from_u64(x: u64) -> Self::BaseType { + [Fp2E::from(x), Fp2E::zero()] + } + + /// Takes as input an element of BaseType and returns the internal representation + /// of that element in the field. + /// Note: for this case this is simply the identity, because the components + /// already have correct representations. + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + x + } +} + +/*impl IsSubFieldOf for Mersenne31Field { + fn mul( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FpE::from(a) * b[0]; + let c1 = FpE::from(a) * b[1]; + [c0, c1] + } + + fn add( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::add(a, b[0].value())); + let c1 = FieldElement::from_raw(*b[1].value()); + [c0, c1] + } + + fn div( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree2ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) + } + + fn sub( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::sub(a, b[0].value())); + let c1 = FieldElement::from_raw(::neg(b[1].value())); + [c0, c1] + } + + fn embed(a: Self::BaseType) -> ::BaseType { + [FieldElement::from_raw(a), FieldElement::zero()] + } + + #[cfg(feature = "alloc")] + fn to_subfield_vec( + b: ::BaseType, + ) -> alloc::vec::Vec { + b.into_iter().map(|x| x.to_raw()).collect() + } +} +*/ + /* pub type Mersenne31ComplexQuadraticExtensionField = QuadraticExtensionField; @@ -197,6 +347,7 @@ mod tests { use super::*; type Fp2E = FieldElement; + type Fp4E = FieldElement; #[test] fn add_real_one_plus_one_is_two() { @@ -351,9 +502,50 @@ mod tests { } #[test] - fn mul_base_field_with_degree_2_extension() { + fn mul_fpe_by_fp2e() { let a = FpE::from(3); let b = Fp2E::new([FpE::from(2), FpE::from(4)]); assert_eq!(a * b, Fp2E::new([FpE::from(6), FpE::from(12)])) } + + #[test] + fn mul_fp4_is_correct() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + + let b = Fp4E::new([ + Fp2E::new([FpE::from(6), FpE::from(7)]), + Fp2E::new([FpE::from(8), FpE::from(9)]), + ]); + + let a2 = FieldElement::::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + + let b = FieldElement::::new([ + Fp2E::new([FpE::from(6), FpE::from(7)]), + Fp2E::new([FpE::from(8), FpE::from(9)]), + ]); + } + + #[test] + fn mul_fp4_by_zero_is_zero() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + assert_eq!(Fp4E::zero(), a * Fp4E::zero()) + } + + #[test] + fn mul_fp4_by_one_is_identity() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + assert_eq!(a, a.clone() * Fp4E::one()) + } } From 5e1f533c1f87980617b3ef9e8f2e9dcd0fa620cd Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 23 Sep 2024 16:06:55 -0300 Subject: [PATCH 11/93] add fp4 isField and isSubField operations and benchmarks --- math/benches/criterion_field.rs | 4 +- math/benches/fields/mersenne31.rs | 427 ++++++++++++------ math/src/field/fields/mersenne31/extension.rs | 252 +++++++---- math/src/field/fields/mersenne31/field.rs | 2 +- 4 files changed, 445 insertions(+), 240 deletions(-) diff --git a/math/benches/criterion_field.rs b/math/benches/criterion_field.rs index 5738c9930..8fe2f210c 100644 --- a/math/benches/criterion_field.rs +++ b/math/benches/criterion_field.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use pprof::criterion::{Output, PProfProfiler}; mod fields; -use fields::mersenne31::mersenne31_ops_benchmarks; +use fields::mersenne31::{mersenne31_extension_ops_benchmarks, mersenne31_ops_benchmarks}; use fields::mersenne31_montgomery::mersenne31_mont_ops_benchmarks; use fields::{ stark252::starkfield_ops_benchmarks, u64_goldilocks::u64_goldilocks_ops_benchmarks, @@ -12,7 +12,7 @@ use fields::{ criterion_group!( name = field_benches; config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = mersenne31_ops_benchmarks + targets = mersenne31_extension_ops_benchmarks //targets = starkfield_ops_benchmarks, mersenne31_ops_benchmarks, mersenne31_mont_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks ); criterion_main!(field_benches); diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs index bc06f6fde..d6b2f90ae 100644 --- a/math/benches/fields/mersenne31.rs +++ b/math/benches/fields/mersenne31.rs @@ -1,10 +1,21 @@ use std::hint::black_box; use criterion::Criterion; -use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use lambdaworks_math::{ + elliptic_curve::edwards::curves::bandersnatch::field, + field::{ + element::FieldElement, + fields::mersenne31::{ + extension::{Degree2ExtensionField, Degree4ExtensionField, Degree4ExtensionFieldV2}, + field::Mersenne31Field, + }, + }, +}; use rand::random; pub type F = FieldElement; +pub type Fp2E = FieldElement; +pub type Fp4E = FieldElement; #[inline(never)] #[no_mangle] @@ -17,6 +28,122 @@ pub fn rand_field_elements(num: usize) -> Vec<(F, F)> { result } +//TODO: Check if this is the correct way to bench. +pub fn rand_fp4e(num: usize) -> Vec<(Fp4E, Fp4E)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + result.push(( + Fp4E::new([ + Fp2E::new([F::new(random()), F::new(random())]), + Fp2E::new([F::new(random()), F::new(random())]), + ]), + Fp4E::new([ + Fp2E::new([F::new(random()), F::new(random())]), + Fp2E::new([F::new(random()), F::new(random())]), + ]), + )); + } + result +} + +pub fn rand_fp4e_v2( + num: usize, +) -> Vec<( + FieldElement, + FieldElement, +)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + result.push(( + FieldElement::::new([ + Fp2E::new([F::new(random()), F::new(random())]), + Fp2E::new([F::new(random()), F::new(random())]), + ]), + FieldElement::::new([ + Fp2E::new([F::new(random()), F::new(random())]), + Fp2E::new([F::new(random()), F::new(random())]), + ]), + )); + } + result +} + +pub fn mersenne31_extension_ops_benchmarks(c: &mut Criterion) { + let input: Vec> = [1000000].into_iter().map(rand_fp4e).collect::>(); + let input_v2: Vec< + Vec<( + FieldElement, + FieldElement, + )>, + > = [1000000].into_iter().map(rand_fp4e_v2).collect::>(); + + let mut group = c.benchmark_group("Mersenne31 Fp4 operations"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Mul of Fp4 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input_v2.clone().into_iter() { + group.bench_with_input(format!("Mul of Fp4 V2 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Square of Fp4 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input_v2.clone().into_iter() { + group.bench_with_input( + format!("Square of Fp4 V2 {:?}", &i.len()), + &i, + |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }, + ); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Inv of Fp4 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } + + for i in input_v2.clone().into_iter() { + group.bench_with_input(format!("Inv of Fp4 V2 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } +} + pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { let input: Vec> = [1, 10, 100, 1000, 10000, 100000, 1000000] .into_iter() @@ -24,89 +151,91 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { .collect::>(); let mut group = c.benchmark_group("Mersenne31 operations"); - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, y) in i { - // black_box(black_box(x) + black_box(y)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("mul {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, y) in i { - // black_box(black_box(x) * black_box(y)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("pow by 1 {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x).pow(1_u64)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("square {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x).square()); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("square with pow {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x).pow(2_u64)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("square with mul {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x) * black_box(x)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input( - // format!("pow {:?}", &i.len()), - // &(i, 5u64), - // |bench, (i, a)| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x).pow(*a)); - // } - // }); - // }, - // ); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, y) in i { - // black_box(black_box(x) - black_box(y)); - // } - // }); - // }); - // } + /* + for i in input.clone().into_iter() { + group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) + black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("mul {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("pow by 1 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(1_u64)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square with pow {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(2_u64)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square with mul {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x) * black_box(x)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input( + format!("pow {:?}", &i.len()), + &(i, 5u64), + |bench, (i, a)| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(*a)); + } + }); + }, + ); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) - black_box(y)); + } + }); + }); + } + */ for i in input.clone().into_iter() { group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { @@ -128,68 +257,70 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { }); } - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, y) in i { - // black_box(black_box(x) == black_box(y)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("sqrt {:?}", &i.len()), &i, |bench, i| { - // bench.iter(|| { - // for (x, _) in i { - // black_box(black_box(x).sqrt()); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("sqrt squared {:?}", &i.len()), &i, |bench, i| { - // let i: Vec = i.iter().map(|(x, _)| x * x).collect(); - // bench.iter(|| { - // for x in &i { - // black_box(black_box(x).sqrt()); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("bitand {:?}", &i.len()), &i, |bench, i| { - // // Note: we should strive to have the number of limbs be generic... ideally this benchmark group itself should have a generic type that we call into from the main runner. - // let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); - // bench.iter(|| { - // for (x, y) in &i { - // black_box(black_box(*x) & black_box(*y)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("bitor {:?}", &i.len()), &i, |bench, i| { - // let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); - // bench.iter(|| { - // for (x, y) in &i { - // black_box(black_box(*x) | black_box(*y)); - // } - // }); - // }); - // } - - // for i in input.clone().into_iter() { - // group.bench_with_input(format!("bitxor {:?}", &i.len()), &i, |bench, i| { - // let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); - // bench.iter(|| { - // for (x, y) in &i { - // black_box(black_box(*x) ^ black_box(*y)); - // } - // }); - // }); - // } + /* + for i in input.clone().into_iter() { + group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) == black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sqrt {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).sqrt()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sqrt squared {:?}", &i.len()), &i, |bench, i| { + let i: Vec = i.iter().map(|(x, _)| x * x).collect(); + bench.iter(|| { + for x in &i { + black_box(black_box(x).sqrt()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitand {:?}", &i.len()), &i, |bench, i| { + // Note: we should strive to have the number of limbs be generic... ideally this benchmark group itself should have a generic type that we call into from the main runner. + let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) & black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitor {:?}", &i.len()), &i, |bench, i| { + let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) | black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitxor {:?}", &i.len()), &i, |bench, i| { + let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) ^ black_box(*y)); + } + }); + }); + } + */ } diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index dc3748eb5..7e1bf2fb9 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -154,6 +154,9 @@ impl HasQuadraticNonResidue for Mersenne31LevelTwoResidue } pub type Degree4ExtensionFieldV2 = QuadraticExtensionField; + +/// I = 0 + 1 * i is the non-residue of Fp2 used to define Fp4. +pub const I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(0), FpE::const_from_raw(1)]); #[derive(Clone, Debug)] pub struct Degree4ExtensionField; @@ -166,6 +169,16 @@ impl IsField for Degree4ExtensionField { [&a[0] + &b[0], &a[1] + &b[1]] } + /// Returns the component wise subtraction of `a` and `b` + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] - &b[0], &a[1] - &b[1]] + } + + /// Returns the component wise negation of `a` + fn neg(a: &Self::BaseType) -> Self::BaseType { + [-&a[0], -&a[1]] + } + /// Returns the multiplication of `a` and `b` using the following fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { // a = a0 + a1 * u, b = b0 + b1 * u, where @@ -175,8 +188,13 @@ impl IsField for Degree4ExtensionField { let [b00, b01] = b[0].value(); let [b10, b11] = b[1].value(); - let c00 = a00 * b00 - a01 * b01 - a11 * b11 + (a10 * b10).double() - a10 * b11 - b10 * a11; - let c01 = a00 * b01 + a01 * b00 + a10 * b10 - (a10 * b11).double() + (b10 * a11).double(); + let a10b10 = a10 * b10; + let a10b11 = a10 * b11; + let a11b10 = b10 * a11; + let a11b11 = a11 * b11; + + let c00 = a00 * b00 - a01 * b01 + a10b10.double() - a10b11 - a11b10 - (a11b11).double(); + let c01 = a00 * b01 + a01 * b00 + a10b10 + a10b11.double() + a11b10.double() - a11b11; let c10 = a00 * b10 - a01 * b11 + a10 * b00 - b01 * a11; let c11 = a00 * b11 + a01 * b10 + a10 * b01 + a11 * b00; @@ -184,25 +202,27 @@ impl IsField for Degree4ExtensionField { } fn square(a: &Self::BaseType) -> Self::BaseType { - let [a0, a1] = a; - let v0 = a0 * a1; - let c0 = (a0 + a1) * (a0 - a1); - let c1 = v0.double(); - [c0, c1] - } - /// Returns the component wise subtraction of `a` and `b` - fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - [&a[0] - &b[0], &a[1] - &b[1]] - } + // a = a0 + a1 * u, where + // a0 = a00 + a01 * i and a1 = a11 + a11 * i + let [a00, a01] = a[0].value(); + let [a10, a11] = a[1].value(); - /// Returns the component wise negation of `a` - fn neg(a: &Self::BaseType) -> Self::BaseType { - [-&a[0], -&a[1]] + let a10a10 = a10 * a10; + let a10a11 = a10 * a11; + let a11a11 = a11 * a11; + + let c00 = a00 * a00 - a01 * a01 + a10a10.double() - a10a11.double() - (a11a11).double(); + let c01 = (a00 * a01).double() + a10a10 + a10a11.double().double() - a11a11; + let c10 = (a00 * a10).double() - (a01 * a11).double(); + let c11 = (a00 * a11).double() + (a01 * a10).double(); + + [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] } /// Returns the multiplicative inverse of `a` fn inv(a: &Self::BaseType) -> Result { - let inv_norm = (a[0].square() + a[1].square()).inv()?; + let a1_square = a[1].square(); + let inv_norm = (a[0].square() - a1_square.double() - a1_square * I).inv()?; Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) } @@ -240,107 +260,73 @@ impl IsField for Degree4ExtensionField { } } -/*impl IsSubFieldOf for Mersenne31Field { +impl IsSubFieldOf for Mersenne31Field { fn mul( a: &Self::BaseType, - b: &::BaseType, - ) -> ::BaseType { - let c0 = FpE::from(a) * b[0]; - let c1 = FpE::from(a) * b[1]; + b: &::BaseType, + ) -> ::BaseType { + let c0 = FpE::from(a) * &b[0]; + let c1 = FpE::from(a) * &b[1]; [c0, c1] } fn add( a: &Self::BaseType, - b: &::BaseType, - ) -> ::BaseType { - let c0 = FieldElement::from_raw(::add(a, b[0].value())); + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(>::add( + a, + b[0].value(), + )); let c1 = FieldElement::from_raw(*b[1].value()); [c0, c1] } fn div( a: &Self::BaseType, - b: &::BaseType, - ) -> ::BaseType { - let b_inv = Degree2ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree4ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) } fn sub( a: &Self::BaseType, - b: &::BaseType, - ) -> ::BaseType { - let c0 = FieldElement::from_raw(::sub(a, b[0].value())); - let c1 = FieldElement::from_raw(::neg(b[1].value())); + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(>::sub( + a, + b[0].value(), + )); + let c1 = FieldElement::from_raw(::neg(b[1].value())); [c0, c1] } - fn embed(a: Self::BaseType) -> ::BaseType { - [FieldElement::from_raw(a), FieldElement::zero()] + fn embed(a: Self::BaseType) -> ::BaseType { + [ + Fp2E::from_raw(>::embed(a)), + Fp2E::zero(), + ] } #[cfg(feature = "alloc")] fn to_subfield_vec( - b: ::BaseType, + b: ::BaseType, ) -> alloc::vec::Vec { - b.into_iter().map(|x| x.to_raw()).collect() - } -} -*/ - -/* -pub type Mersenne31ComplexQuadraticExtensionField = - QuadraticExtensionField; - -//TODO: Check this should be for complex and not base field -impl HasQuadraticNonResidue for Mersenne31Complex { - // Verifiable in Sage with - // ```sage - // p = 2**31 - 1 # Mersenne31 - // F = GF(p) # The base field GF(p) - // R. = F[] # The polynomial ring over F - // K. = F.extension(x^2 + 1) # The complex extension field - // R2. = K[] - // f2 = y^2 - i - 2 - // assert f2.is_irreducible() - // ``` - fn residue() -> FieldElement { - FieldElement::from(&Mersenne31Complex::from_base_type([ - FieldElement::::from(2), - FieldElement::::one(), - ])) + // TODO: Repace this for with a map similarly to this: + // b.into_iter().map(|x| x.to_raw()).collect() + let mut result = Vec::new(); + for fp2e in b { + result.push(fp2e.value()[0].to_raw()); + result.push(fp2e.value()[1].to_raw()); + } + result } } -*/ - -/* -pub type Mersenne31ComplexCubicExtensionField = - CubicExtensionField; - -impl HasCubicNonResidue for Mersenne31Complex { - // Verifiable in Sage with - // ```sage - // p = 2**31 - 1 # Mersenne31 - // F = GF(p) # The base field GF(p) - // R. = F[] # The polynomial ring over F - // K. = F.extension(x^2 + 1) # The complex extension field - // R2. = K[] - // f2 = y^3 - 5*i - // assert f2.is_irreducible() - // ``` - fn residue() -> FieldElement { - FieldElement::from(&Mersenne31Complex::from_base_type([ - FieldElement::::zero(), - FieldElement::::from(5), - ])) - } -} -*/ #[cfg(test)] mod tests { - use core::{num::FpCategory, ops::Neg}; + use core::ops::Neg; use crate::field::fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER; @@ -488,7 +474,7 @@ mod tests { } #[test] - fn mul() { + fn mul_fp2() { let a = Fp2E::new([FpE::from(2), FpE::from(2)]); let b = Fp2E::new([FpE::from(4), FpE::from(5)]); let c = Fp2E::new([-FpE::from(2), FpE::from(18)]); @@ -525,10 +511,37 @@ mod tests { Fp2E::new([FpE::from(4), FpE::from(5)]), ]); - let b = FieldElement::::new([ + let b2 = FieldElement::::new([ + Fp2E::new([FpE::from(6), FpE::from(7)]), + Fp2E::new([FpE::from(8), FpE::from(9)]), + ]); + + assert_eq!((&a * &b).value(), (a2 * b2).value()) + } + + #[test] + fn mul_fp4_is_correct_2() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2147483647), FpE::from(2147483648)]), + Fp2E::new([FpE::from(2147483649), FpE::from(2147483650)]), + ]); + + let b = Fp4E::new([ + Fp2E::new([FpE::from(6), FpE::from(7)]), + Fp2E::new([FpE::from(8), FpE::from(9)]), + ]); + + let a2 = FieldElement::::new([ + Fp2E::new([FpE::from(2147483647), FpE::from(2147483648)]), + Fp2E::new([FpE::from(2147483649), FpE::from(2147483650)]), + ]); + + let b2 = FieldElement::::new([ Fp2E::new([FpE::from(6), FpE::from(7)]), Fp2E::new([FpE::from(8), FpE::from(9)]), ]); + + assert_eq!((&a * &b).value(), (a2 * b2).value()) } #[test] @@ -548,4 +561,65 @@ mod tests { ]); assert_eq!(a, a.clone() * Fp4E::one()) } + + #[test] + fn square_fp4_is_correct() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + + let a2 = FieldElement::::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + + assert_eq!(a.square().value(), a2.square().value()) + } + + #[test] + fn square_fp4_equals_mul_two_times() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(3), FpE::from(4)]), + Fp2E::new([FpE::from(5), FpE::from(6)]), + ]); + + assert_eq!(a.square(), &a * &a) + } + + #[test] + fn fp4_mul_by_inv_is_one() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2147483647), FpE::from(2147483648)]), + Fp2E::new([FpE::from(2147483649), FpE::from(2147483650)]), + ]); + + assert_eq!(&a * a.inv().unwrap(), Fp4E::one()) + } + + #[test] + fn embed_fp_with_fp4() { + let a = FpE::from(3); + let a_extension = Fp4E::from(3); + assert_eq!(a.to_extension::(), a_extension); + } + + #[test] + fn add_fp_and_fp4() { + let a = FpE::from(3); + let a_extension = Fp4E::from(3); + let b = Fp4E::from(2); + assert_eq!(a + &b, a_extension + b); + } + + #[test] + fn mul_fp_by_fp4() { + let a = FpE::from(30000000000); + let a_extension = a.clone().to_extension::(); + let b = Fp4E::new([ + Fp2E::new([FpE::from(1), FpE::from(2)]), + Fp2E::new([FpE::from(3), FpE::from(4)]), + ]); + assert_eq!(a * &b, a_extension * b); + } } diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index bb3b59b4c..88015e0a7 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -43,7 +43,7 @@ impl Mersenne31Field { Self::from_u64(iter.map(|x| (x as u64)).sum::()) } - /// Computes a * 2^k, with |k| < 31 + /// Computes a * 2^k, with 0 < k < 31 pub fn mul_power_two(a: u32, k: u32) -> u32 { // If a uses 32 bits, then a * 2^k uses 32 + k bits. let msb = (a & (u32::MAX << 32 - k)) >> (32 - k - 1); // The k+1 msb. From fde7faa6f2a8815f8316053a20c95968d661df40 Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 23 Sep 2024 17:38:20 -0300 Subject: [PATCH 12/93] new version for fp4 mul based on the paper --- math/src/field/fields/mersenne31/extension.rs | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 7e1bf2fb9..95f791abb 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -155,34 +155,33 @@ impl HasQuadraticNonResidue for Mersenne31LevelTwoResidue pub type Degree4ExtensionFieldV2 = QuadraticExtensionField; -/// I = 0 + 1 * i is the non-residue of Fp2 used to define Fp4. +/// I = 0 + 1 * i pub const I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(0), FpE::const_from_raw(1)]); + +pub const TWO_PLUS_I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(2), FpE::const_from_raw(1)]); #[derive(Clone, Debug)] pub struct Degree4ExtensionField; impl IsField for Degree4ExtensionField { - //Elements represents a[0] = real, a[1] = imaginary type BaseType = [Fp2E; 2]; - /// Returns the component wise addition of `a` and `b` fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { [&a[0] + &b[0], &a[1] + &b[1]] } - /// Returns the component wise subtraction of `a` and `b` fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { [&a[0] - &b[0], &a[1] - &b[1]] } - /// Returns the component wise negation of `a` fn neg(a: &Self::BaseType) -> Self::BaseType { [-&a[0], -&a[1]] } - /// Returns the multiplication of `a` and `b` using the following fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + /* + // VERSION 1 (distribution by hand): // a = a0 + a1 * u, b = b0 + b1 * u, where - // a0 = a00 + a01 * i, a1 = a11 + a11 * i, etc + // a0 = a00 + a01 * i, a1 = a11 + a11 * i, etc. let [a00, a01] = a[0].value(); let [a10, a11] = a[1].value(); let [b00, b01] = b[0].value(); @@ -199,6 +198,15 @@ impl IsField for Degree4ExtensionField { let c11 = a00 * b11 + a01 * b10 + a10 * b01 + a11 * b00; [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] + */ + + // VERSION 2 (paper): + let a0b0 = &a[0] * &b[0]; + let a1b1 = &a[1] * &b[1]; + [ + &a0b0 + TWO_PLUS_I * &a1b1, + (&a[0] + &a[1]) * (&b[0] + &b[1]) - a0b0 - a1b1, + ] } fn square(a: &Self::BaseType) -> Self::BaseType { @@ -219,42 +227,32 @@ impl IsField for Degree4ExtensionField { [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] } - /// Returns the multiplicative inverse of `a` fn inv(a: &Self::BaseType) -> Result { let a1_square = a[1].square(); let inv_norm = (a[0].square() - a1_square.double() - a1_square * I).inv()?; Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) } - /// Returns the division of `a` and `b` fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { ::mul(a, &Self::inv(b).unwrap()) } - /// Returns a boolean indicating whether `a` and `b` are equal component wise. fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { a[0] == b[0] && a[1] == b[1] } - /// Returns the additive neutral element of the field extension. fn zero() -> Self::BaseType { [Fp2E::zero(), Fp2E::zero()] } - /// Returns the multiplicative neutral element of the field extension. fn one() -> Self::BaseType { [Fp2E::one(), Fp2E::zero()] } - /// Returns the element `x * 1` where 1 is the multiplicative neutral element. fn from_u64(x: u64) -> Self::BaseType { [Fp2E::from(x), Fp2E::zero()] } - /// Takes as input an element of BaseType and returns the internal representation - /// of that element in the field. - /// Note: for this case this is simply the identity, because the components - /// already have correct representations. fn from_base_type(x: Self::BaseType) -> Self::BaseType { x } @@ -474,7 +472,7 @@ mod tests { } #[test] - fn mul_fp2() { + fn mul_fp2_is_correct() { let a = Fp2E::new([FpE::from(2), FpE::from(2)]); let b = Fp2E::new([FpE::from(4), FpE::from(5)]); let c = Fp2E::new([-FpE::from(2), FpE::from(18)]); From 60c8197c633d3b021dd842f02e3d4bd08591c4b6 Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 23 Sep 2024 17:53:51 -0300 Subject: [PATCH 13/93] add mul of a fp2e by non-residue --- math/src/field/fields/mersenne31/extension.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 95f791abb..5e4f1e45a 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -159,6 +159,13 @@ pub type Degree4ExtensionFieldV2 = pub const I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(0), FpE::const_from_raw(1)]); pub const TWO_PLUS_I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(2), FpE::const_from_raw(1)]); + +pub fn mul_fp2_by_nonresidue(a: &Fp2E) -> Fp2E { + Fp2E::new([ + a.value()[0].double() - a.value()[1], + &a.value()[1].double() + &a.value()[0], + ]) +} #[derive(Clone, Debug)] pub struct Degree4ExtensionField; @@ -204,7 +211,7 @@ impl IsField for Degree4ExtensionField { let a0b0 = &a[0] * &b[0]; let a1b1 = &a[1] * &b[1]; [ - &a0b0 + TWO_PLUS_I * &a1b1, + &a0b0 + mul_fp2_by_nonresidue(&a1b1), (&a[0] + &a[1]) * (&b[0] + &b[1]) - a0b0 - a1b1, ] } From 25b8869d96d83db2d7657a1c6a196ee41bb26589 Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 23 Sep 2024 18:23:28 -0300 Subject: [PATCH 14/93] change inv using mul_fp2_by_non_resiude --- math/src/field/fields/mersenne31/extension.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 5e4f1e45a..592e68648 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -235,8 +235,12 @@ impl IsField for Degree4ExtensionField { } fn inv(a: &Self::BaseType) -> Result { - let a1_square = a[1].square(); - let inv_norm = (a[0].square() - a1_square.double() - a1_square * I).inv()?; + // VERSION 1: + // let a1_square = a[1].square(); + // let inv_norm = (a[0].square() - a1_square.double() - a1_square * I).inv()?; + + // VERSION 2: + let inv_norm = (a[0].square() - mul_fp2_by_nonresidue(&a[1].square())).inv()?; Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) } From 1187be833bf19dd88a3d6d3db10ade13fdd4fcd8 Mon Sep 17 00:00:00 2001 From: Nicole Date: Tue, 24 Sep 2024 10:43:23 -0300 Subject: [PATCH 15/93] save work --- math/src/field/fields/mersenne31/extension.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 592e68648..10659142d 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -158,6 +158,7 @@ pub type Degree4ExtensionFieldV2 = /// I = 0 + 1 * i pub const I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(0), FpE::const_from_raw(1)]); +/// TWO_PLUS_I = 2 + 1 is the non-residue of Fp2 used for the Fp4 extension. pub const TWO_PLUS_I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(2), FpE::const_from_raw(1)]); pub fn mul_fp2_by_nonresidue(a: &Fp2E) -> Fp2E { From ecff11e7cccfbb643a9eebaaa4cae08fd72e30e8 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 24 Sep 2024 10:52:32 -0300 Subject: [PATCH 16/93] wip fp2 test --- math/src/field/fields/mersenne31/extension.rs | 129 ++++++++---------- 1 file changed, 55 insertions(+), 74 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 431c7309b..0904f2de2 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -1,10 +1,6 @@ use crate::field::{ element::FieldElement, errors::FieldError, - extensions::{ - cubic::{CubicExtensionField, HasCubicNonResidue}, - quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, - }, traits::{IsField, IsSubFieldOf}, }; @@ -92,39 +88,33 @@ impl IsField for Degree2ExtensionField { } impl IsSubFieldOf for Mersenne31Field { - fn mul( + fn add( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let c0 = FpE::from(a) * b[0]; - let c1 = FpE::from(a) * b[1]; - [c0, c1] + [FpE::from(a) + b[0], FpE::from(a) + b[1]] } - fn add( + fn sub( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let c0 = FieldElement::from_raw(::add(a, b[0].value())); - let c1 = FieldElement::from_raw(*b[1].value()); - [c0, c1] + [FpE::from(a) - b[0], FpE::from(a) - b[1]] } - fn div( + fn mul( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let b_inv = Degree2ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + [FpE::from(a) * b[0], FpE::from(a) * b[1]] } - fn sub( + fn div( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let c0 = FieldElement::from_raw(::sub(a, b[0].value())); - let c1 = FieldElement::from_raw(::neg(b[1].value())); - [c0, c1] + let b_inv = Degree2ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) } fn embed(a: Self::BaseType) -> ::BaseType { @@ -139,58 +129,9 @@ impl IsSubFieldOf for Mersenne31Field { } } -/* -pub type Mersenne31ComplexQuadraticExtensionField = - QuadraticExtensionField; - -//TODO: Check this should be for complex and not base field -impl HasQuadraticNonResidue for Mersenne31Complex { - // Verifiable in Sage with - // ```sage - // p = 2**31 - 1 # Mersenne31 - // F = GF(p) # The base field GF(p) - // R. = F[] # The polynomial ring over F - // K. = F.extension(x^2 + 1) # The complex extension field - // R2. = K[] - // f2 = y^2 - i - 2 - // assert f2.is_irreducible() - // ``` - fn residue() -> FieldElement { - FieldElement::from(&Mersenne31Complex::from_base_type([ - FieldElement::::from(2), - FieldElement::::one(), - ])) - } -} -*/ - -/* -pub type Mersenne31ComplexCubicExtensionField = - CubicExtensionField; - -impl HasCubicNonResidue for Mersenne31Complex { - // Verifiable in Sage with - // ```sage - // p = 2**31 - 1 # Mersenne31 - // F = GF(p) # The base field GF(p) - // R. = F[] # The polynomial ring over F - // K. = F.extension(x^2 + 1) # The complex extension field - // R2. = K[] - // f2 = y^3 - 5*i - // assert f2.is_irreducible() - // ``` - fn residue() -> FieldElement { - FieldElement::from(&Mersenne31Complex::from_base_type([ - FieldElement::::zero(), - FieldElement::::from(5), - ])) - } -} -*/ - #[cfg(test)] mod tests { - use core::{num::FpCategory, ops::Neg}; + use core::ops::Neg; use crate::field::fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER; @@ -200,6 +141,7 @@ mod tests { #[test] fn add_real_one_plus_one_is_two() { + println!("{:?}", Fp2E::from(2)); assert_eq!(Fp2E::one() + Fp2E::one(), Fp2E::from(2)) } @@ -223,7 +165,6 @@ mod tests { #[test] fn add_complex_one_plus_one_two() { - //Manually declare the complex part to one let one_i = Fp2E::new([FpE::zero(), FpE::one()]); let two_i = Fp2E::new([FpE::zero(), FpE::from(2)]); assert_eq!(&one_i + &one_i, two_i) @@ -351,9 +292,49 @@ mod tests { } #[test] - fn mul_base_field_with_degree_2_extension() { - let a = FpE::from(3); - let b = Fp2E::new([FpE::from(2), FpE::from(4)]); - assert_eq!(a * b, Fp2E::new([FpE::from(6), FpE::from(12)])) + fn test_base_field_2_extension_add() { + let a = Fee::new([FE::from(0), FE::from(3)]); + let b = Fee::new([-FE::from(2), FE::from(8)]); + let expected_result = Fee::new([FE::from(0) - FE::from(2), FE::from(3) + FE::from(8)]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_base_field_2_extension_sub() { + let a = Fee::new([FE::from(0), FE::from(3)]); + let b = Fee::new([-FE::from(2), FE::from(8)]); + let expected_result = Fee::new([FE::from(0) + FE::from(2), FE::from(3) - FE::from(8)]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_degree_2_extension_mul() { + let a = Fee::new([FE::from(12), FE::from(5)]); + let b = Fee::new([-FE::from(4), FE::from(2)]); + let expected_result = Fee::new([ + FE::from(12) * (-FE::from(4)) + + FE::from(5) * FE::from(2) * Babybear31PrimeField::residue(), + FE::from(12) * FE::from(2) + FE::from(5) * (-FE::from(4)), + ]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_degree_2_extension_inv() { + let a = Fee::new([FE::from(12), FE::from(5)]); + let inv_norm = (FE::from(12).pow(2_u64) + - Babybear31PrimeField::residue() * FE::from(5).pow(2_u64)) + .inv() + .unwrap(); + let expected_result = Fee::new([FE::from(12) * &inv_norm, -&FE::from(5) * inv_norm]); + assert_eq!(a.inv().unwrap(), expected_result); + } + + #[test] + fn test_degree_2_extension_div() { + let a = Fee::new([FE::from(12), FE::from(5)]); + let b = Fee::new([-FE::from(4), FE::from(2)]); + let expected_result = &a * b.inv().unwrap(); + assert_eq!(a / b, expected_result); } } From ae1446a8e773e3fddbe544251879a1ca3d66905b Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 24 Sep 2024 12:00:29 -0300 Subject: [PATCH 17/93] add fp2 tests --- math/src/field/fields/mersenne31/extension.rs | 203 +++++++++++++----- 1 file changed, 146 insertions(+), 57 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index c2cfdd9ad..549b5f9c8 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -1,14 +1,8 @@ -use crate::{ - elliptic_curve::short_weierstrass::curves::bls12_381::field_extension::LevelTwoResidue, - field::{ - element::FieldElement, - errors::FieldError, - extensions::{ - cubic::{CubicExtensionField, HasCubicNonResidue}, - quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, - }, - traits::{IsField, IsSubFieldOf}, - }, +use crate::field::{ + element::FieldElement, + errors::FieldError, + extensions::quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, + traits::{IsField, IsSubFieldOf}, }; use super::field::Mersenne31Field; @@ -336,12 +330,12 @@ mod tests { use super::*; + type FpE = FieldElement; type Fp2E = FieldElement; type Fp4E = FieldElement; #[test] fn add_real_one_plus_one_is_two() { - println!("{:?}", Fp2E::from(2)); assert_eq!(Fp2E::one() + Fp2E::one(), Fp2E::from(2)) } @@ -492,49 +486,144 @@ mod tests { } #[test] - fn test_base_field_2_extension_add() { - let a = Fee::new([FE::from(0), FE::from(3)]); - let b = Fee::new([-FE::from(2), FE::from(8)]); - let expected_result = Fee::new([FE::from(0) - FE::from(2), FE::from(3) + FE::from(8)]); + fn test_fp2_add() { + let a = Fp2E::new([FpE::from(0), FpE::from(3)]); + let b = Fp2E::new([-FpE::from(2), FpE::from(8)]); + let expected_result = Fp2E::new([FpE::from(0) - FpE::from(2), FpE::from(3) + FpE::from(8)]); assert_eq!(a + b, expected_result); } #[test] - fn test_base_field_2_extension_sub() { - let a = Fee::new([FE::from(0), FE::from(3)]); - let b = Fee::new([-FE::from(2), FE::from(8)]); - let expected_result = Fee::new([FE::from(0) + FE::from(2), FE::from(3) - FE::from(8)]); + fn test_fp2_add_2() { + let a = Fp2E::new([FpE::from(2), FpE::from(4)]); + let b = Fp2E::new([-FpE::from(2), -FpE::from(4)]); + let expected_result = Fp2E::new([FpE::from(2) - FpE::from(2), FpE::from(4) - FpE::from(4)]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_fp2_add_3() { + let a = Fp2E::new([FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER), FpE::from(1)]); + let b = Fp2E::new([FpE::from(1), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let expected_result = Fp2E::new([FpE::from(1), FpE::from(1)]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_fp2_sub() { + let a = Fp2E::new([FpE::from(0), FpE::from(3)]); + let b = Fp2E::new([-FpE::from(2), FpE::from(8)]); + let expected_result = Fp2E::new([FpE::from(0) + FpE::from(2), FpE::from(3) - FpE::from(8)]); assert_eq!(a - b, expected_result); } #[test] - fn test_degree_2_extension_mul() { - let a = Fee::new([FE::from(12), FE::from(5)]); - let b = Fee::new([-FE::from(4), FE::from(2)]); - let expected_result = Fee::new([ - FE::from(12) * (-FE::from(4)) - + FE::from(5) * FE::from(2) * Babybear31PrimeField::residue(), - FE::from(12) * FE::from(2) + FE::from(5) * (-FE::from(4)), - ]); + fn test_fp2_sub_2() { + let a = Fp2E::new([FpE::zero(), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let b = Fp2E::new([FpE::one(), -FpE::one()]); + let expected_result = + Fp2E::new([FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)), FpE::one()]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_fp2_sub_3() { + let a = Fp2E::new([FpE::from(5), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let b = Fp2E::new([FpE::from(5), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_fp2_mul() { + let a = Fp2E::new([FpE::from(12), FpE::from(5)]); + let b = Fp2E::new([-FpE::from(4), FpE::from(2)]); + let expected_result = Fp2E::new([-FpE::from(58), FpE::new(4)]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_2() { + let a = Fp2E::new([FpE::one(), FpE::zero()]); + let b = Fp2E::new([FpE::from(12), -FpE::from(8)]); + let expected_result = Fp2E::new([FpE::from(12), -FpE::new(8)]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_3() { + let a = Fp2E::new([FpE::zero(), FpE::zero()]); + let b = Fp2E::new([FpE::from(2), FpE::from(7)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_4() { + let a = Fp2E::new([FpE::from(2), FpE::from(7)]); + let b = Fp2E::new([FpE::zero(), FpE::zero()]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); assert_eq!(a * b, expected_result); } #[test] - fn test_degree_2_extension_inv() { - let a = Fee::new([FE::from(12), FE::from(5)]); - let inv_norm = (FE::from(12).pow(2_u64) - - Babybear31PrimeField::residue() * FE::from(5).pow(2_u64)) - .inv() - .unwrap(); - let expected_result = Fee::new([FE::from(12) * &inv_norm, -&FE::from(5) * inv_norm]); + fn test_fp2_mul_5() { + let a = Fp2E::new([FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER), FpE::one()]); + let b = Fp2E::new([FpE::from(2), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::from(2)]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_inv() { + let a = Fp2E::new([FpE::one(), FpE::zero()]); + let expected_result = Fp2E::new([FpE::one(), FpE::zero()]); + assert_eq!(a.inv().unwrap(), expected_result); + } + + #[test] + fn test_fp2_inv_2() { + let a = Fp2E::new([FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)), FpE::one()]); + let expected_result = Fp2E::new([FpE::from(1073741823), FpE::from(1073741823)]); + assert_eq!(a.inv().unwrap(), expected_result); + } + + #[test] + fn test_fp2_inv_3() { + let a = Fp2E::new([FpE::from(2063384121), FpE::from(1232183486)]); + let expected_result = Fp2E::new([FpE::from(1244288232), FpE::from(1321511038)]); assert_eq!(a.inv().unwrap(), expected_result); } #[test] - fn test_degree_2_extension_div() { - let a = Fee::new([FE::from(12), FE::from(5)]); - let b = Fee::new([-FE::from(4), FE::from(2)]); - let expected_result = &a * b.inv().unwrap(); + fn test_fp2_mul_inv() { + let a = Fp2E::new([FpE::from(12), FpE::from(5)]); + let b = a.inv().unwrap(); + let expected_result = Fp2E::new([FpE::one(), FpE::zero()]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_div() { + let a = Fp2E::new([FpE::from(12), FpE::from(5)]); + let b = Fp2E::new([FpE::from(4), FpE::from(2)]); + let expected_result = Fp2E::new([FpE::from(644245097), FpE::from(1288490188)]); + assert_eq!(a / b, expected_result); + } + + #[test] + fn test_fp2_div_2() { + let a = Fp2E::new([FpE::from(4), FpE::from(7)]); + let b = Fp2E::new([FpE::one(), FpE::zero()]); + let expected_result = Fp2E::new([FpE::from(4), FpE::from(7)]); + assert_eq!(a / b, expected_result); + } + + #[test] + fn test_fp2_div_3() { + let a = Fp2E::new([FpE::zero(), FpE::zero()]); + let b = Fp2E::new([FpE::from(3), FpE::from(12)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); assert_eq!(a / b, expected_result); } @@ -648,22 +737,22 @@ mod tests { assert_eq!(a.to_extension::(), a_extension); } - #[test] - fn add_fp_and_fp4() { - let a = FpE::from(3); - let a_extension = Fp4E::from(3); - let b = Fp4E::from(2); - assert_eq!(a + &b, a_extension + b); - } - - #[test] - fn mul_fp_by_fp4() { - let a = FpE::from(30000000000); - let a_extension = a.clone().to_extension::(); - let b = Fp4E::new([ - Fp2E::new([FpE::from(1), FpE::from(2)]), - Fp2E::new([FpE::from(3), FpE::from(4)]), - ]); - assert_eq!(a * &b, a_extension * b); - } + // #[test] + // fn add_fp_and_fp4() { + // let a = FpE::from(3); + // let a_extension = Fp4E::from(3); + // let b = Fp4E::from(2); + // assert_eq!(a + &b, a_extension + b); + // } + + // #[test] + // fn mul_fp_by_fp4() { + // let a = FpE::from(30000000000); + // let a_extension = a.clone().to_extension::(); + // let b = Fp4E::new([ + // Fp2E::new([FpE::from(1), FpE::from(2)]), + // Fp2E::new([FpE::from(3), FpE::from(4)]), + // ]); + // assert_eq!(a * &b, a_extension * b); + // } } From fa12fc5bfa8d493515bfc0dd7d2c2444c5342227 Mon Sep 17 00:00:00 2001 From: Nicole Date: Tue, 24 Sep 2024 12:36:33 -0300 Subject: [PATCH 18/93] add 2 * a^2 - 1 function --- math/src/field/fields/mersenne31/field.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 88015e0a7..84768c82f 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -56,6 +56,12 @@ impl Mersenne31Field { (0..order).for_each(|_| res = Self::square(&res)); res } + + /// TODO: Ask how should we implement this function. + /// Computes 2a^2 - 1 + pub fn two_square_minus_one(a: &u32) -> u32 { + Self::from_u64(((u64::from(*a) * u64::from(*a)) << 1) - 1) + } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -462,4 +468,13 @@ mod tests { let a = FE::from(1234); assert_eq!(a + a, a.double()) } + + #[test] + fn two_square_minus_one_is_correct() { + let a = FE::from(2147483650); + assert_eq!( + FE::from(&F::two_square_minus_one(&a.value())), + a.square().double() - FE::one() + ) + } } From 461485b6a04af9c16f39a666b545b1968897c5dc Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 25 Sep 2024 14:59:47 -0300 Subject: [PATCH 19/93] use karatsuba in fp4 mul version 1 --- math/src/field/fields/mersenne31/extension.rs | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 549b5f9c8..4f185701e 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -196,13 +196,47 @@ impl IsField for Degree4ExtensionField { [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] */ - // VERSION 2 (paper): + // VERSION 1 using Karatsuba: + // a = a0 + a1 * u, b = b0 + b1 * u, where + // a0 = a00 + a01 * i, a1 = a11 + a11 * i, etc. + let [a00, a01] = a[0].value(); + let [a10, a11] = a[1].value(); + let [b00, b01] = b[0].value(); + let [b10, b11] = b[1].value(); + + let a00b00 = a00 * b00; + let a00b10 = a00 * b10; + let a01b01 = a01 * b01; + let a01b11 = a01 * b11; + let a10b00 = a10 * b00; + let a10b10 = a10 * b10; + let a10b11 = a10 * b11; + let a11b01 = a11 * b01; + let a11b10 = b10 * a11; + let a11b11 = a11 * b11; + + let c00 = a00b00 - a01b01 + a10b10.double() - a10b11 - a11b10 - (a11b11).double(); + let c01 = (a00 + a01) * (b00 + b01) - a00b00 - a01b01 + + a10b10 + + a10b11.double() + + a11b10.double() + - a11b11; + let c10 = a00b10 - a01b11 + a10b00 - a11b01; + let c11 = (a00 + a01) * (b10 + b11) - a00b10 - a01b11 + (a10 + a11) * (b00 + b01) + - a10b00 + - a11b01; + + [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] + + /* + // VERSION 2 (paper, karatsuba): let a0b0 = &a[0] * &b[0]; let a1b1 = &a[1] * &b[1]; [ &a0b0 + mul_fp2_by_nonresidue(&a1b1), (&a[0] + &a[1]) * (&b[0] + &b[1]) - a0b0 - a1b1, ] + */ } fn square(a: &Self::BaseType) -> Self::BaseType { From d4e3f408b4f5f9e730931b888dba6a3689c66779 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 25 Sep 2024 16:00:01 -0300 Subject: [PATCH 20/93] clean up --- math/benches/fields/mersenne31.rs | 75 +----- math/src/field/fields/mersenne31/extension.rs | 223 +++--------------- math/src/field/fields/mersenne31/field.rs | 76 ++---- 3 files changed, 65 insertions(+), 309 deletions(-) diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs index d6b2f90ae..de0a95e03 100644 --- a/math/benches/fields/mersenne31.rs +++ b/math/benches/fields/mersenne31.rs @@ -1,14 +1,11 @@ use std::hint::black_box; use criterion::Criterion; -use lambdaworks_math::{ - elliptic_curve::edwards::curves::bandersnatch::field, - field::{ - element::FieldElement, - fields::mersenne31::{ - extension::{Degree2ExtensionField, Degree4ExtensionField, Degree4ExtensionFieldV2}, - field::Mersenne31Field, - }, +use lambdaworks_math::field::{ + element::FieldElement, + fields::mersenne31::{ + extension::{Degree2ExtensionField, Degree4ExtensionField}, + field::Mersenne31Field, }, }; use rand::random; @@ -46,36 +43,8 @@ pub fn rand_fp4e(num: usize) -> Vec<(Fp4E, Fp4E)> { result } -pub fn rand_fp4e_v2( - num: usize, -) -> Vec<( - FieldElement, - FieldElement, -)> { - let mut result = Vec::with_capacity(num); - for _ in 0..result.capacity() { - result.push(( - FieldElement::::new([ - Fp2E::new([F::new(random()), F::new(random())]), - Fp2E::new([F::new(random()), F::new(random())]), - ]), - FieldElement::::new([ - Fp2E::new([F::new(random()), F::new(random())]), - Fp2E::new([F::new(random()), F::new(random())]), - ]), - )); - } - result -} - pub fn mersenne31_extension_ops_benchmarks(c: &mut Criterion) { let input: Vec> = [1000000].into_iter().map(rand_fp4e).collect::>(); - let input_v2: Vec< - Vec<( - FieldElement, - FieldElement, - )>, - > = [1000000].into_iter().map(rand_fp4e_v2).collect::>(); let mut group = c.benchmark_group("Mersenne31 Fp4 operations"); @@ -89,16 +58,6 @@ pub fn mersenne31_extension_ops_benchmarks(c: &mut Criterion) { }); } - for i in input_v2.clone().into_iter() { - group.bench_with_input(format!("Mul of Fp4 V2 {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, y) in i { - black_box(black_box(x) * black_box(y)); - } - }); - }); - } - for i in input.clone().into_iter() { group.bench_with_input(format!("Square of Fp4 {:?}", &i.len()), &i, |bench, i| { bench.iter(|| { @@ -109,20 +68,6 @@ pub fn mersenne31_extension_ops_benchmarks(c: &mut Criterion) { }); } - for i in input_v2.clone().into_iter() { - group.bench_with_input( - format!("Square of Fp4 V2 {:?}", &i.len()), - &i, - |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).square()); - } - }); - }, - ); - } - for i in input.clone().into_iter() { group.bench_with_input(format!("Inv of Fp4 {:?}", &i.len()), &i, |bench, i| { bench.iter(|| { @@ -132,16 +77,6 @@ pub fn mersenne31_extension_ops_benchmarks(c: &mut Criterion) { }); }); } - - for i in input_v2.clone().into_iter() { - group.bench_with_input(format!("Inv of Fp4 V2 {:?}", &i.len()), &i, |bench, i| { - bench.iter(|| { - for (x, _) in i { - black_box(black_box(x).inv().unwrap()); - } - }); - }); - } } pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 4f185701e..a05b58b97 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -1,7 +1,6 @@ use crate::field::{ element::FieldElement, errors::FieldError, - extensions::quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, traits::{IsField, IsSubFieldOf}, }; @@ -14,6 +13,15 @@ type FpE = FieldElement; #[derive(Clone, Debug)] pub struct Degree2ExtensionField; +impl Degree2ExtensionField { + pub fn mul_fp2_by_nonresidue(a: &Fp2E) -> Fp2E { + Fp2E::new([ + a.value()[0].double() - a.value()[1], + a.value()[1].double() + a.value()[0], + ]) + } +} + impl IsField for Degree2ExtensionField { //Elements represents a[0] = real, a[1] = imaginary type BaseType = [FpE; 2]; @@ -132,29 +140,6 @@ impl IsSubFieldOf for Mersenne31Field { type Fp2E = FieldElement; -/// Extension of degree 4 defined with lambdaworks quadratic extension to test the correctness of Degree4ExtensionField -#[derive(Debug, Clone)] -pub struct Mersenne31LevelTwoResidue; -impl HasQuadraticNonResidue for Mersenne31LevelTwoResidue { - fn residue() -> Fp2E { - Fp2E::new([FpE::from(2), FpE::one()]) - } -} -pub type Degree4ExtensionFieldV2 = - QuadraticExtensionField; - -/// I = 0 + 1 * i -pub const I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(0), FpE::const_from_raw(1)]); - -/// TWO_PLUS_I = 2 + 1 is the non-residue of Fp2 used for the Fp4 extension. -pub const TWO_PLUS_I: Fp2E = Fp2E::const_from_raw([FpE::const_from_raw(2), FpE::const_from_raw(1)]); - -pub fn mul_fp2_by_nonresidue(a: &Fp2E) -> Fp2E { - Fp2E::new([ - a.value()[0].double() - a.value()[1], - &a.value()[1].double() + &a.value()[0], - ]) -} #[derive(Clone, Debug)] pub struct Degree4ExtensionField; @@ -174,96 +159,27 @@ impl IsField for Degree4ExtensionField { } fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - /* - // VERSION 1 (distribution by hand): - // a = a0 + a1 * u, b = b0 + b1 * u, where - // a0 = a00 + a01 * i, a1 = a11 + a11 * i, etc. - let [a00, a01] = a[0].value(); - let [a10, a11] = a[1].value(); - let [b00, b01] = b[0].value(); - let [b10, b11] = b[1].value(); - - let a10b10 = a10 * b10; - let a10b11 = a10 * b11; - let a11b10 = b10 * a11; - let a11b11 = a11 * b11; - - let c00 = a00 * b00 - a01 * b01 + a10b10.double() - a10b11 - a11b10 - (a11b11).double(); - let c01 = a00 * b01 + a01 * b00 + a10b10 + a10b11.double() + a11b10.double() - a11b11; - let c10 = a00 * b10 - a01 * b11 + a10 * b00 - b01 * a11; - let c11 = a00 * b11 + a01 * b10 + a10 * b01 + a11 * b00; - - [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] - */ - - // VERSION 1 using Karatsuba: - // a = a0 + a1 * u, b = b0 + b1 * u, where - // a0 = a00 + a01 * i, a1 = a11 + a11 * i, etc. - let [a00, a01] = a[0].value(); - let [a10, a11] = a[1].value(); - let [b00, b01] = b[0].value(); - let [b10, b11] = b[1].value(); - - let a00b00 = a00 * b00; - let a00b10 = a00 * b10; - let a01b01 = a01 * b01; - let a01b11 = a01 * b11; - let a10b00 = a10 * b00; - let a10b10 = a10 * b10; - let a10b11 = a10 * b11; - let a11b01 = a11 * b01; - let a11b10 = b10 * a11; - let a11b11 = a11 * b11; - - let c00 = a00b00 - a01b01 + a10b10.double() - a10b11 - a11b10 - (a11b11).double(); - let c01 = (a00 + a01) * (b00 + b01) - a00b00 - a01b01 - + a10b10 - + a10b11.double() - + a11b10.double() - - a11b11; - let c10 = a00b10 - a01b11 + a10b00 - a11b01; - let c11 = (a00 + a01) * (b10 + b11) - a00b10 - a01b11 + (a10 + a11) * (b00 + b01) - - a10b00 - - a11b01; - - [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] - - /* - // VERSION 2 (paper, karatsuba): + // Algorithm from: https://github.com/ingonyama-zk/papers/blob/main/Mersenne31_polynomial_arithmetic.pdf (page 5): let a0b0 = &a[0] * &b[0]; let a1b1 = &a[1] * &b[1]; [ - &a0b0 + mul_fp2_by_nonresidue(&a1b1), + &a0b0 + Degree2ExtensionField::mul_fp2_by_nonresidue(&a1b1), (&a[0] + &a[1]) * (&b[0] + &b[1]) - a0b0 - a1b1, ] - */ } fn square(a: &Self::BaseType) -> Self::BaseType { - // a = a0 + a1 * u, where - // a0 = a00 + a01 * i and a1 = a11 + a11 * i - let [a00, a01] = a[0].value(); - let [a10, a11] = a[1].value(); - - let a10a10 = a10 * a10; - let a10a11 = a10 * a11; - let a11a11 = a11 * a11; - - let c00 = a00 * a00 - a01 * a01 + a10a10.double() - a10a11.double() - (a11a11).double(); - let c01 = (a00 * a01).double() + a10a10 + a10a11.double().double() - a11a11; - let c10 = (a00 * a10).double() - (a01 * a11).double(); - let c11 = (a00 * a11).double() + (a01 * a10).double(); - - [Fp2E::new([c00, c01]), Fp2E::new([c10, c11])] + let a0_square = &a[0].square(); + let a1_square = &a[1].square(); + [ + a0_square + Degree2ExtensionField::mul_fp2_by_nonresidue(&a1_square), + (&a[0] + &a[1]).square() - a0_square - a1_square, + ] } fn inv(a: &Self::BaseType) -> Result { - // VERSION 1: - // let a1_square = a[1].square(); - // let inv_norm = (a[0].square() - a1_square.double() - a1_square * I).inv()?; - - // VERSION 2: - let inv_norm = (a[0].square() - mul_fp2_by_nonresidue(&a[1].square())).inv()?; + let inv_norm = + (a[0].square() - Degree2ExtensionField::mul_fp2_by_nonresidue(&a[1].square())).inv()?; Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) } @@ -661,56 +577,6 @@ mod tests { assert_eq!(a / b, expected_result); } - #[test] - fn mul_fp4_is_correct() { - let a = Fp4E::new([ - Fp2E::new([FpE::from(2), FpE::from(3)]), - Fp2E::new([FpE::from(4), FpE::from(5)]), - ]); - - let b = Fp4E::new([ - Fp2E::new([FpE::from(6), FpE::from(7)]), - Fp2E::new([FpE::from(8), FpE::from(9)]), - ]); - - let a2 = FieldElement::::new([ - Fp2E::new([FpE::from(2), FpE::from(3)]), - Fp2E::new([FpE::from(4), FpE::from(5)]), - ]); - - let b2 = FieldElement::::new([ - Fp2E::new([FpE::from(6), FpE::from(7)]), - Fp2E::new([FpE::from(8), FpE::from(9)]), - ]); - - assert_eq!((&a * &b).value(), (a2 * b2).value()) - } - - #[test] - fn mul_fp4_is_correct_2() { - let a = Fp4E::new([ - Fp2E::new([FpE::from(2147483647), FpE::from(2147483648)]), - Fp2E::new([FpE::from(2147483649), FpE::from(2147483650)]), - ]); - - let b = Fp4E::new([ - Fp2E::new([FpE::from(6), FpE::from(7)]), - Fp2E::new([FpE::from(8), FpE::from(9)]), - ]); - - let a2 = FieldElement::::new([ - Fp2E::new([FpE::from(2147483647), FpE::from(2147483648)]), - Fp2E::new([FpE::from(2147483649), FpE::from(2147483650)]), - ]); - - let b2 = FieldElement::::new([ - Fp2E::new([FpE::from(6), FpE::from(7)]), - Fp2E::new([FpE::from(8), FpE::from(9)]), - ]); - - assert_eq!((&a * &b).value(), (a2 * b2).value()) - } - #[test] fn mul_fp4_by_zero_is_zero() { let a = Fp4E::new([ @@ -729,21 +595,6 @@ mod tests { assert_eq!(a, a.clone() * Fp4E::one()) } - #[test] - fn square_fp4_is_correct() { - let a = Fp4E::new([ - Fp2E::new([FpE::from(2), FpE::from(3)]), - Fp2E::new([FpE::from(4), FpE::from(5)]), - ]); - - let a2 = FieldElement::::new([ - Fp2E::new([FpE::from(2), FpE::from(3)]), - Fp2E::new([FpE::from(4), FpE::from(5)]), - ]); - - assert_eq!(a.square().value(), a2.square().value()) - } - #[test] fn square_fp4_equals_mul_two_times() { let a = Fp4E::new([ @@ -771,22 +622,22 @@ mod tests { assert_eq!(a.to_extension::(), a_extension); } - // #[test] - // fn add_fp_and_fp4() { - // let a = FpE::from(3); - // let a_extension = Fp4E::from(3); - // let b = Fp4E::from(2); - // assert_eq!(a + &b, a_extension + b); - // } - - // #[test] - // fn mul_fp_by_fp4() { - // let a = FpE::from(30000000000); - // let a_extension = a.clone().to_extension::(); - // let b = Fp4E::new([ - // Fp2E::new([FpE::from(1), FpE::from(2)]), - // Fp2E::new([FpE::from(3), FpE::from(4)]), - // ]); - // assert_eq!(a * &b, a_extension * b); - // } + #[test] + fn add_fp_and_fp4() { + let a = FpE::from(3); + let a_extension = Fp4E::from(3); + let b = Fp4E::from(2); + assert_eq!(a + &b, a_extension + b); + } + + #[test] + fn mul_fp_by_fp4() { + let a = FpE::from(30000000000); + let a_extension = a.clone().to_extension::(); + let b = Fp4E::new([ + Fp2E::new([FpE::from(1), FpE::from(2)]), + Fp2E::new([FpE::from(3), FpE::from(4)]), + ]); + assert_eq!(a * &b, a_extension * b); + } } diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 84768c82f..19d6d5ad1 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -100,51 +100,29 @@ impl IsField for Mersenne31Field { if *x == Self::zero() || *x == MERSENNE_31_PRIME_FIELD_ORDER { return Err(FieldError::InvZeroError); } - // OLD VERSION: - let p101 = Self::mul(&Self::pow_2(x, 2), x); - let p1111 = Self::mul(&Self::square(&p101), &p101); - let p11111111 = Self::mul(&Self::pow_2(&p1111, 4u32), &p1111); - let p111111110000 = Self::pow_2(&p11111111, 4u32); - let p111111111111 = Self::mul(&p111111110000, &p1111); - let p1111111111111111 = Self::mul(&Self::pow_2(&p111111110000, 4u32), &p11111111); - let p1111111111111111111111111111 = - Self::mul(&Self::pow_2(&p1111111111111111, 12u32), &p111111111111); - let p1111111111111111111111111111101 = - Self::mul(&Self::pow_2(&p1111111111111111111111111111, 3u32), &p101); - Ok(p1111111111111111111111111111101) - - // // OLD VERSION: - // let t0 = sqn(*x, 2) * x; - // let t1 = t0 * t0 * t0; - // let t2 = sqn(t1, 3) * t0; - // let t3 = t2 * t2 * t0; - // let t4 = sqn(t3, 8) * t3; - // let t5 = sqn(t4, 8) * t3; - // Ok(sqn(t5, 7) * t2) - - // NEW VERSION: - // let mut a: u32 = 1; - // let mut b: u32 = 0; - // let mut y: u32 = x.clone(); - // let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; - // let q: u32 = 31; - // let mut e: u32; - // let mut temp2: u32; - - // loop { - // e = y.trailing_zeros(); - // y >>= e; - // a = Self::mul_power_two(a, q - e); - // if y == 1 { - // return Ok(a); - // }; - // temp2 = a.wrapping_add(b); - // b = a; - // a = temp2; - // temp2 = y.wrapping_add(z); - // z = y; - // y = temp2; - // } + // Algorithm from: https://github.com/ingonyama-zk/papers/blob/main/Mersenne31_polynomial_arithmetic.pdf (page 3). + let mut a: u32 = 1; + let mut b: u32 = 0; + let mut y: u32 = x.clone(); + let mut z: u32 = MERSENNE_31_PRIME_FIELD_ORDER; + let q: u32 = 31; + let mut e: u32; + let mut temp: u32; + + loop { + e = y.trailing_zeros(); + y >>= e; + a = Self::mul_power_two(a, q - e); + if y == 1 { + return Ok(a); + }; + temp = a.wrapping_add(b); + b = a; + a = temp; + temp = y.wrapping_add(z); + z = y; + y = temp; + } } /// Returns the division of `a` and `b`. @@ -183,14 +161,6 @@ impl IsField for Mersenne31Field { } } -/// Computes `a^(2*n)`. -pub fn sqn(mut a: u32, n: usize) -> u32 { - for _ in 0..n { - a = Mersenne31Field::mul(&a, &a); - } - a -} - impl IsPrimeField for Mersenne31Field { type RepresentativeType = u32; From f0437c73a481ff53bd3fb9751e58e065f7bc1b5f Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 25 Sep 2024 17:09:58 -0300 Subject: [PATCH 21/93] fix Fp as subfield of Fp2. Tests Fp plus Fp4 is now correct --- math/src/field/fields/mersenne31/extension.rs | 36 +++++++------------ 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index a05b58b97..02ca226ee 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -101,14 +101,14 @@ impl IsSubFieldOf for Mersenne31Field { a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - [FpE::from(a) + b[0], FpE::from(a) + b[1]] + [FpE::from(a) + b[0], b[1]] } fn sub( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - [FpE::from(a) - b[0], FpE::from(a) - b[1]] + [FpE::from(a) - b[0], -b[1]] } fn mul( @@ -209,45 +209,35 @@ impl IsField for Degree4ExtensionField { } impl IsSubFieldOf for Mersenne31Field { - fn mul( + fn add( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let c0 = FpE::from(a) * &b[0]; - let c1 = FpE::from(a) * &b[1]; - [c0, c1] + [FpE::from(a) + &b[0], b[1].clone()] } - fn add( + fn sub( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let c0 = FieldElement::from_raw(>::add( - a, - b[0].value(), - )); - let c1 = FieldElement::from_raw(*b[1].value()); - [c0, c1] + [FpE::from(a) - &b[0], -&b[1]] } - fn div( + fn mul( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let b_inv = Degree4ExtensionField::inv(b).unwrap(); - >::mul(a, &b_inv) + let c0 = FpE::from(a) * &b[0]; + let c1 = FpE::from(a) * &b[1]; + [c0, c1] } - fn sub( + fn div( a: &Self::BaseType, b: &::BaseType, ) -> ::BaseType { - let c0 = FieldElement::from_raw(>::sub( - a, - b[0].value(), - )); - let c1 = FieldElement::from_raw(::neg(b[1].value())); - [c0, c1] + let b_inv = Degree4ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) } fn embed(a: Self::BaseType) -> ::BaseType { From 2c5a3014918c0b56fa79ff92e3a5c74b8307994b Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 11:23:52 -0300 Subject: [PATCH 22/93] fix inv --- math/src/field/fields/mersenne31/field.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 19d6d5ad1..4e2bd4914 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -45,10 +45,9 @@ impl Mersenne31Field { /// Computes a * 2^k, with 0 < k < 31 pub fn mul_power_two(a: u32, k: u32) -> u32 { - // If a uses 32 bits, then a * 2^k uses 32 + k bits. - let msb = (a & (u32::MAX << 32 - k)) >> (32 - k - 1); // The k+1 msb. - let lsb = (a & (u32::MAX >> k)) << k; // The 31-k lsb with k zeros. - lsb + msb + let msb = (a & (u32::MAX << 31 - k)) >> (31 - k); // The k + 1 msb corridos con 31 - k ceros a la izq. + let lsb = (a & (u32::MAX >> k + 1)) << k; // The 31 - k lsb with k zeros a la derecha. + Self::weak_reduce(msb + lsb) } pub fn pow_2(a: &u32, order: u32) -> u32 { @@ -111,8 +110,10 @@ impl IsField for Mersenne31Field { loop { e = y.trailing_zeros(); - y >>= e; - a = Self::mul_power_two(a, q - e); + if e != 0 { + y >>= e; + a = Self::mul_power_two(a, q - e) + } if y == 1 { return Ok(a); }; @@ -447,4 +448,10 @@ mod tests { a.square().double() - FE::one() ) } + + #[test] + fn mul_by_inv() { + let x = 3476715743_u32; + assert_eq!(FE::from(&x).inv().unwrap() * FE::from(&x), FE::one()); + } } From 4a42dbc524a9b2f00d602bfb8ab4dc512f296a93 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 11:27:15 -0300 Subject: [PATCH 23/93] fix comments --- math/src/field/fields/mersenne31/field.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 4e2bd4914..38721210a 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -45,8 +45,8 @@ impl Mersenne31Field { /// Computes a * 2^k, with 0 < k < 31 pub fn mul_power_two(a: u32, k: u32) -> u32 { - let msb = (a & (u32::MAX << 31 - k)) >> (31 - k); // The k + 1 msb corridos con 31 - k ceros a la izq. - let lsb = (a & (u32::MAX >> k + 1)) << k; // The 31 - k lsb with k zeros a la derecha. + let msb = (a & (u32::MAX << 31 - k)) >> (31 - k); // The k + 1 msf shifted right . + let lsb = (a & (u32::MAX >> k + 1)) << k; // The 31 - k lsb shifted left. Self::weak_reduce(msb + lsb) } From 21d09c68ccfc15aea858f93610b3488d2866972e Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 11:30:43 -0300 Subject: [PATCH 24/93] create crate --- provers/circle/Cargo.toml | 8 ++++++++ provers/circle/src/main.rs | 3 +++ 2 files changed, 11 insertions(+) create mode 100644 provers/circle/Cargo.toml create mode 100644 provers/circle/src/main.rs diff --git a/provers/circle/Cargo.toml b/provers/circle/Cargo.toml new file mode 100644 index 000000000..0509c8ec5 --- /dev/null +++ b/provers/circle/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "circle" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] diff --git a/provers/circle/src/main.rs b/provers/circle/src/main.rs new file mode 100644 index 000000000..e7a11a969 --- /dev/null +++ b/provers/circle/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} From f7efd9016ecf90b1a42bed7005e13246d42dad2e Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 11:37:37 -0300 Subject: [PATCH 25/93] Revert "create crate" This reverts commit 21d09c68ccfc15aea858f93610b3488d2866972e. --- provers/circle/Cargo.toml | 8 -------- provers/circle/src/main.rs | 3 --- 2 files changed, 11 deletions(-) delete mode 100644 provers/circle/Cargo.toml delete mode 100644 provers/circle/src/main.rs diff --git a/provers/circle/Cargo.toml b/provers/circle/Cargo.toml deleted file mode 100644 index 0509c8ec5..000000000 --- a/provers/circle/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[package] -name = "circle" -version.workspace = true -edition.workspace = true -license.workspace = true -repository.workspace = true - -[dependencies] diff --git a/provers/circle/src/main.rs b/provers/circle/src/main.rs deleted file mode 100644 index e7a11a969..000000000 --- a/provers/circle/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -} From eddd9be66f05e07ea43e19056ec3c0f38a8d1b3b Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 17:39:11 -0300 Subject: [PATCH 26/93] add circle point implementation --- math/src/circle/errors.rs | 5 + math/src/circle/mod.rs | 2 + math/src/circle/point.rs | 222 ++++++++++++++++++ math/src/field/fields/mersenne31/extension.rs | 9 +- math/src/lib.rs | 1 + 5 files changed, 238 insertions(+), 1 deletion(-) create mode 100644 math/src/circle/errors.rs create mode 100644 math/src/circle/mod.rs create mode 100644 math/src/circle/point.rs diff --git a/math/src/circle/errors.rs b/math/src/circle/errors.rs new file mode 100644 index 000000000..d2f569d19 --- /dev/null +++ b/math/src/circle/errors.rs @@ -0,0 +1,5 @@ + +#[derive(Debug)] +pub enum CircleError { + InvalidValue, +} diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs new file mode 100644 index 000000000..7aaf74437 --- /dev/null +++ b/math/src/circle/mod.rs @@ -0,0 +1,2 @@ +pub mod point; +pub mod errors; \ No newline at end of file diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs new file mode 100644 index 000000000..57698b1f7 --- /dev/null +++ b/math/src/circle/point.rs @@ -0,0 +1,222 @@ +use crate::field::traits::IsField; +use crate::field::{element::FieldElement, fields::mersenne31::{field::Mersenne31Field, extension::{Degree2ExtensionField, Degree4ExtensionField}}}; +use super::errors::CircleError; +use std::cmp::PartialEq; +use std::ops::Add; +use std::process::Output; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct CirclePoint { + pub x: FieldElement, + pub y: FieldElement, +} + +pub trait HasCircleParams { + type FE; + + fn circle_generator() -> (FieldElement, FieldElement); +} + + +impl HasCircleParams for Mersenne31Field { + type FE = FieldElement; + + // This could be a constant instead of a function + fn circle_generator() -> (Self::FE, Self::FE){ + ( + Self::FE::from(&2), + Self::FE::from(&1268011823) + ) + } +} + +impl HasCircleParams for Degree4ExtensionField { + type FE = FieldElement; + + // This could be a constant instead of a function + fn circle_generator() -> (FieldElement, FieldElement){ + ( + Degree4ExtensionField::from_coeffcients( + FieldElement::::one(), + FieldElement::::zero(), + FieldElement::::from(&478637715), + FieldElement::::from(&513582971), + ), + + Degree4ExtensionField::from_coeffcients( + FieldElement::::from(992285211), + FieldElement::::from(649143431), + FieldElement::::from(&740191619), + FieldElement::::from(&1186584352) + + ) + ) + } +} + +impl> CirclePoint{ + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(CirclePoint { x, y }) + } else { + Err(CircleError::InvalidValue) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } + + /// Computes (a0, a1) + (b0, b1) = (a0 * b0 - a1 * b1, a0 * b1 + a1 * b0) + pub fn add(a: Self, b: Self) -> Self { + let x = &a.x * &b.x - &a.y * &b.y; + let y = a.x * b.y + a.y * b.x; + CirclePoint{ x, y } + } + + + /// Computes n * (x, y) = (x ,y) + ... + (x, y) n-times. + pub fn mul(self, mut scalar: u128) -> Self { + let mut res = Self::zero(); + let mut cur = self; + loop { + if scalar == 0 { + return res + } + if scalar & 1 == 1 { + res = res + cur.clone(); + } + cur = cur.double(); + scalar >>= 1; + } + } + + /// Computes 2(x, y) = (2x^2 - 1, 2xy). + pub fn double(self) -> Self { + Self::new( + self.x.square().double() - FieldElement::one(), + self.x.double() * self.y, + ).unwrap() + } + + /// Computes 2^n * (x, y). + pub fn repeated_double(self, n: u32) -> Self { + let mut res = self; + for _ in 0..n { + res = res.double(); + } + res + } + + /// Computes the inverse of the point. + /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation. + pub fn conjugate(self) -> Self { + Self { + x: self.x, + y: -self.y, + } + } + + pub fn eq(a: Self, b: Self) -> bool { + a.x == b.x && a.y == b.y + } + + pub fn generator() -> Self { + CirclePoint::new( + F::circle_generator().0, + F::circle_generator().1 + ).unwrap() + } +} + +impl> PartialEq for CirclePoint { + fn eq(&self, other: &Self) -> bool { + CirclePoint::eq(self.clone(), other.clone()) + } +} + +impl> Add for CirclePoint { + type Output = CirclePoint; + fn add(self, other: Self) -> Self { + CirclePoint::add(self, other) + } +} + +#[cfg(test)] +mod tests { + use super::*; + type F = Mersenne31Field; + type FE = FieldElement; + type G = CirclePoint; + + type Fp4 = Degree4ExtensionField; + type Fp4E = FieldElement; + type G4 = CirclePoint; + + #[test] + fn create_new_valid_g_point() { + let valid_point = G::new(FE::one(), FE::zero()).unwrap(); + let expected = G { x: FE::one(), y: FE::zero() }; + assert_eq!(valid_point, expected) + } + + #[test] + fn create_new_valid_g4_point() { + let valid_point = G4::new(Fp4E::one(), Fp4E::zero()).unwrap(); + let expected = G4 { x: Fp4E::one(), y: Fp4E::zero() }; + assert_eq!(valid_point, expected) + } + + #[test] + fn create_new_invalid_circle_point() { + let invalid_point = G::new(FE::one(), FE::one()); + assert!(invalid_point.is_err()) + } + + #[test] + fn create_new_invalid_g4_circle_point() { + let invalid_point = G4::new(Fp4E::one(), Fp4E::one()); + assert!(invalid_point.is_err()) + } + + #[test] + fn zero_plus_zero_is_zero() { + let a = G::zero(); + let b = G::zero(); + assert_eq!(a + b, G::zero()) + } + + #[test] + fn generator_plus_zero_is_generator(){ + let g = G::generator(); + let zero = G::zero(); + assert_eq!(g.clone() + zero, g) + } + + #[test] + fn double_equals_mul_two() { + let g = G::generator(); + assert_eq!(g.clone().double(), G::mul(g, 2)) + } + + #[test] + fn mul_eight_equals_double_three_times(){ + let g = G::generator(); + assert_eq!(g.clone().repeated_double(3), G::mul(g, 8)) + } + + #[test] + fn generator_has_order_two_pow_31 (){ + let g = G::generator(); + let n = 31; + assert_eq!(g.repeated_double(n), G::zero()) + } + + #[test] + fn conjugation_is_inverse_operation () { + let g = G::generator(); + assert_eq!(g.clone() + g.conjugate() , G::zero()) + } +} diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs index 02ca226ee..c72712c7f 100644 --- a/math/src/field/fields/mersenne31/extension.rs +++ b/math/src/field/fields/mersenne31/extension.rs @@ -7,6 +7,8 @@ use crate::field::{ use super::field::Mersenne31Field; type FpE = FieldElement; +type Fp2E = FieldElement; +type Fp4E = FieldElement; //Note: The inverse calculation in mersenne31/plonky3 differs from the default quadratic extension so I implemented the complex extension. ////////////////// @@ -138,11 +140,16 @@ impl IsSubFieldOf for Mersenne31Field { } } -type Fp2E = FieldElement; #[derive(Clone, Debug)] pub struct Degree4ExtensionField; +impl Degree4ExtensionField { + pub fn from_coeffcients(a: FpE, b: FpE, c: FpE, d:FpE) -> Fp4E { + Fp4E::new([Fp2E::new([a, b]), Fp2E::new([c, d])]) + } +} + impl IsField for Degree4ExtensionField { type BaseType = [Fp2E; 2]; diff --git a/math/src/lib.rs b/math/src/lib.rs index 56c6e598e..1f5ae60d6 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -3,6 +3,7 @@ #[cfg(feature = "alloc")] extern crate alloc; +pub mod circle; pub mod cyclic_group; pub mod elliptic_curve; pub mod errors; From ba738c7d6eeee70f0d177bf874ff6b4aff2a13b7 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 17:55:59 -0300 Subject: [PATCH 27/93] add group order --- math/src/circle/point.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 57698b1f7..00a056cb8 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -16,6 +16,7 @@ pub trait HasCircleParams { type FE; fn circle_generator() -> (FieldElement, FieldElement); + const ORDER: u128; } @@ -29,6 +30,9 @@ impl HasCircleParams for Mersenne31Field { Self::FE::from(&1268011823) ) } + + /// ORDER = 2^31 + const ORDER: u128 = 2147483648; } impl HasCircleParams for Degree4ExtensionField { @@ -53,6 +57,9 @@ impl HasCircleParams for Degree4ExtensionField { ) ) } + + /// ORDER = (2^31 - 1)^4 - 1 + const ORDER: u128 = 21267647892944572736998860269687930880; } impl> CirclePoint{ @@ -129,6 +136,10 @@ impl> CirclePoint{ F::circle_generator().1 ).unwrap() } + + pub fn group_order() -> u128 { + F::ORDER + } } impl> PartialEq for CirclePoint { @@ -208,12 +219,18 @@ mod tests { } #[test] - fn generator_has_order_two_pow_31 (){ + fn generator_g1_has_order_two_pow_31 (){ let g = G::generator(); let n = 31; assert_eq!(g.repeated_double(n), G::zero()) } + #[test] + fn generator_g4_has_the_order_of_the_group (){ + let g = G4::generator(); + assert_eq!(g.mul(G4::group_order()), G4::zero()) + } + #[test] fn conjugation_is_inverse_operation () { let g = G::generator(); From 85d80c18c244a5bbeda0fb9e2ed52b1ce092ddee Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 26 Sep 2024 18:12:11 -0300 Subject: [PATCH 28/93] rm dependencie --- math/src/circle/point.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 00a056cb8..bea9f6329 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -1,5 +1,5 @@ use crate::field::traits::IsField; -use crate::field::{element::FieldElement, fields::mersenne31::{field::Mersenne31Field, extension::{Degree2ExtensionField, Degree4ExtensionField}}}; +use crate::field::{element::FieldElement, fields::mersenne31::{field::Mersenne31Field, extension::Degree4ExtensionField}}; use super::errors::CircleError; use std::cmp::PartialEq; use std::ops::Add; From 14a7bebc0b5da6f9b612d5996b445d8551fe2f26 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 30 Sep 2024 16:35:48 -0300 Subject: [PATCH 29/93] add cosets --- math/src/circle/cosets.rs | 36 ++++++++++++++++++++++++++++++++++++ math/src/circle/mod.rs | 3 ++- 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 math/src/circle/cosets.rs diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs new file mode 100644 index 000000000..9f5819181 --- /dev/null +++ b/math/src/circle/cosets.rs @@ -0,0 +1,36 @@ +use crate::circle::point::{CirclePoint, HasCircleParams}; +use crate::field::traits::IsField; +use crate::field::fields::mersenne31::field::Mersenne31Field; + +struct Coset { + // Coset: shift + where n = 2^{log_2_size}. + // Example: g_16 + , n = 8, log_2_size = 3, shift = g_16. + log_2_size: u128, + shift: CirclePoint, +} + +impl Coset { + pub fn new(log_2_size: u128, shift: CirclePoint) -> Self { + Coset{ log_2_size, shift } + } + + /// Returns the coset g_2n + + pub fn new_standard(log_2_size: u128) -> Self { + // shift is a generator of the subgroup of order 2n = 2^{log_2_size + 1}. + // We are using that g * k is a generator of the subgroup of order 2^{32 - k}, with k = log_2_size + 1. + let shift = CirclePoint::generator().mul(31 - log_2_size); + Coset{ log_2_size, shift } + } + + /// Given a standard coset g_2n + , returns the subcoset with half size g_2n + + pub fn half_coset(coset: Self) -> Self { + Coset { log_2_size: coset.log_2_size + 1, shift: coset.shift } + } + + /// Given a coset shift + G returns the coset -shift + G. + /// Note that (g_2n + ) U (-g_2n + ) = g_2n + . + pub fn conjugate(coset: Self) -> Self { + Coset { log_2_size: coset.log_2_size, shift: coset.shift.conjugate() } + } +} + diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index 7aaf74437..830b3e60c 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -1,2 +1,3 @@ pub mod point; -pub mod errors; \ No newline at end of file +pub mod errors; +pub mod cosets; \ No newline at end of file From b56d21249b99a2a8cdb57c0e0977cb6a10acaab4 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 1 Oct 2024 11:10:18 -0300 Subject: [PATCH 30/93] add twiddle --- math/src/circle/cosets.rs | 1 - math/src/circle/twiddles.rs | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 math/src/circle/twiddles.rs diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 9f5819181..b21b7acd1 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -33,4 +33,3 @@ impl Coset { Coset { log_2_size: coset.log_2_size, shift: coset.shift.conjugate() } } } - diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs new file mode 100644 index 000000000..8ae4782a3 --- /dev/null +++ b/math/src/circle/twiddles.rs @@ -0,0 +1,22 @@ +// fn compute_twiddles(domain: CircleDomain) -> Vec> { +// assert!(domain.log_n >= 1); +// let mut pts = domain.coset0().collect_vec(); +// reverse_slice_index_bits(&mut pts); +// let mut twiddles = vec![pts.iter().map(|p| p.y).collect_vec()]; +// if domain.log_n >= 2 { +// twiddles.push(pts.iter().step_by(2).map(|p| p.x).collect_vec()); +// for i in 0..(domain.log_n - 2) { +// let prev = twiddles.last().unwrap(); +// assert_eq!(prev.len(), 1 << (domain.log_n - 2 - i)); +// let cur = prev +// .iter() +// .step_by(2) +// .map(|x| x.square().double() - F::one()) +// .collect_vec(); +// twiddles.push(cur); +// } +// } +// twiddles +// } + + From 58aeb0c4da46ce233b5296ebacd8ebb6cf562faf Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 1 Oct 2024 18:27:02 -0300 Subject: [PATCH 31/93] init cfft --- math/src/circle/cfft.rs | 51 ++++++++++++++++++++++++++ math/src/circle/cosets.rs | 65 +++++++++++++++++++++++++++------ math/src/circle/mod.rs | 4 ++- math/src/circle/point.rs | 20 +++++++++++ math/src/circle/twiddles.rs | 72 ++++++++++++++++++++++++++----------- 5 files changed, 181 insertions(+), 31 deletions(-) create mode 100644 math/src/circle/cfft.rs diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs new file mode 100644 index 000000000..413fe83b9 --- /dev/null +++ b/math/src/circle/cfft.rs @@ -0,0 +1,51 @@ +use crate::circle::{cosets::Coset, point::CirclePoint}; +use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; + + +pub fn cfft(input: &mut [FieldElement], twiddles: Vec>>) +{ + // divide input in groups, starting with 1, duplicating the number of groups in each stage. + let mut group_count = 1; + let mut group_size = input.len(); + + // for each group, there'll be group_size / 2 butterflies. + // a butterfly is the atomic operation of a FFT, e.g: (a, b) = (a + wb, a - wb). + // The 0.5 factor is what gives FFT its performance, it recursively halves the problem size + // (group size). + + while group_count < input.len() { + #[allow(clippy::needless_range_loop)] // the suggestion would obfuscate a bit the algorithm + for group in 0..group_count { + let first_in_group = group * group_size; + let first_in_next_group = first_in_group + group_size / 2; + + let w = &twiddles[group]; // a twiddle factor is used per group + + for i in first_in_group..first_in_next_group { + let wi = w[i] * &input[i + group_size / 2]; + + let y0 = &input[i] + &wi; + let y1 = &input[i] - &wi; + + input[i] = y0; + input[i + group_size / 2] = y1; + } + } + group_count *= 2; + group_size /= 2; + } +} + +// #[cfg(test)] +// mod tests { +// use crate::circle::twiddles; + +// use super::*; + +// #[test] +// fn cfft() { +// let coset = Coset::new_standard(3); +// let twiddles = +// assert_eq!(1 << coset.log_2_size, points.len()) +// } +// } diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index b21b7acd1..d01b449b9 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -1,30 +1,36 @@ -use crate::circle::point::{CirclePoint, HasCircleParams}; -use crate::field::traits::IsField; +use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; +use std::iter::successors; -struct Coset { + +#[derive(Debug, Clone)] +pub struct Coset { // Coset: shift + where n = 2^{log_2_size}. // Example: g_16 + , n = 8, log_2_size = 3, shift = g_16. - log_2_size: u128, - shift: CirclePoint, + pub log_2_size: u32, //TODO: Change log_2_size to u8 because log_2_size < 31. + pub shift: CirclePoint, } impl Coset { - pub fn new(log_2_size: u128, shift: CirclePoint) -> Self { + pub fn new(log_2_size: u32, shift: CirclePoint) -> Self { Coset{ log_2_size, shift } } /// Returns the coset g_2n + - pub fn new_standard(log_2_size: u128) -> Self { + pub fn new_standard(log_2_size: u32) -> Self { // shift is a generator of the subgroup of order 2n = 2^{log_2_size + 1}. - // We are using that g * k is a generator of the subgroup of order 2^{32 - k}, with k = log_2_size + 1. - let shift = CirclePoint::generator().mul(31 - log_2_size); + let shift = CirclePoint::get_generator_of_subgroup((log_2_size as u32) + 1); Coset{ log_2_size, shift } } + + /// Returns g_n, the generator of the subgroup of order n = 2^log_2_size. + pub fn get_generator(&self) -> CirclePoint { + CirclePoint::generator().repeated_double(31 - self.log_2_size as u32) + } /// Given a standard coset g_2n + , returns the subcoset with half size g_2n + pub fn half_coset(coset: Self) -> Self { - Coset { log_2_size: coset.log_2_size + 1, shift: coset.shift } + Coset { log_2_size: coset.log_2_size - 1, shift: coset.shift } } /// Given a coset shift + G returns the coset -shift + G. @@ -32,4 +38,43 @@ impl Coset { pub fn conjugate(coset: Self) -> Self { Coset { log_2_size: coset.log_2_size, shift: coset.shift.conjugate() } } + + /// Returns the vector of shift + g for every g in . + /// where g = i * g_n for i = 0, ..., n-1. + pub fn get_coset_points(coset: &Self) -> Vec> { + // g_n the generator of the subgroup of order n. + let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); + let size: u8 = 1 << coset.log_2_size; + successors(Some(coset.shift.clone()), move |prev| Some(prev.clone() + generator_n.clone())) + .take(size.into()) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn coset_points_vector_has_right_size() { + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + assert_eq!(1 << coset.log_2_size, points.len()) + } + + #[test] + fn antipode_of_coset_point_is_in_coset() { + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + let point = points[2].clone(); + let anitpode_point = points[6].clone(); + assert_eq!(anitpode_point, point.antipode()) + } + + #[test] + fn coset_generator_has_right_order() { + let coset = Coset::new(2, CirclePoint::generator().mul(3)); + let generator_n = coset.get_generator(); + assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); + } } diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index 830b3e60c..f02e07dbd 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -1,3 +1,5 @@ pub mod point; pub mod errors; -pub mod cosets; \ No newline at end of file +pub mod cosets; +pub mod twiddles; +pub mod cfft; \ No newline at end of file diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index bea9f6329..c39ceb2ec 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -126,6 +126,14 @@ impl> CirclePoint{ } } + pub fn antipode(self) -> Self { + Self { + x: -self.x, + y: -self.y, + } + } + + pub fn eq(a: Self, b: Self) -> bool { a.x == b.x && a.y == b.y } @@ -136,6 +144,12 @@ impl> CirclePoint{ F::circle_generator().1 ).unwrap() } + + /// Returns the generator of the subgroup of order n = 2^log_2_size. + /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. + pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { + Self::generator().repeated_double(31 - log_2_size) + } pub fn group_order() -> u128 { F::ORDER @@ -236,4 +250,10 @@ mod tests { let g = G::generator(); assert_eq!(g.clone() + g.conjugate() , G::zero()) } + + #[test] + fn subgroup_generator_has_correct_order(){ + let generator_n = G::get_generator_of_subgroup(7); + assert_eq!(generator_n.repeated_double(7), G::zero()); + } } diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 8ae4782a3..69f04bbd2 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -1,22 +1,54 @@ -// fn compute_twiddles(domain: CircleDomain) -> Vec> { -// assert!(domain.log_n >= 1); -// let mut pts = domain.coset0().collect_vec(); -// reverse_slice_index_bits(&mut pts); -// let mut twiddles = vec![pts.iter().map(|p| p.y).collect_vec()]; -// if domain.log_n >= 2 { -// twiddles.push(pts.iter().step_by(2).map(|p| p.x).collect_vec()); -// for i in 0..(domain.log_n - 2) { -// let prev = twiddles.last().unwrap(); -// assert_eq!(prev.len(), 1 << (domain.log_n - 2 - i)); -// let cur = prev -// .iter() -// .step_by(2) -// .map(|x| x.square().double() - F::one()) -// .collect_vec(); -// twiddles.push(cur); -// } -// } -// twiddles -// } +use super::{cosets::Coset, point::CirclePoint}; +use crate::{fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}}; +pub fn get_twiddles(domain: Coset) -> Vec>> { + let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); + in_place_bit_reverse_permute::>(&mut half_domain_points[..]); + let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; + + if domain.log_2_size >= 2 { + twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect()); + + for _ in 0..(domain.log_2_size - 2) { + let prev = twiddles.last().unwrap(); + let cur = prev + .iter() + .step_by(2) + .map(|x| x.square().double() - FieldElement::::one()) + .collect(); + twiddles.push(cur); + } + } + twiddles +} + +#[cfg(test)] +mod tests { + use super::*; + + // #[test] + // fn twiddles_vectors_lenght() { + // let domain = Coset::new_standard(3); + // let twiddles = get_twiddles(domain); + // for i in 0..twiddles.len() - 1 { + // assert_eq!(twiddles[i].len(), 2 * twiddles[i+1].len()) + // } + // } + + #[test] + fn twiddles_test() { + let domain = Coset::new_standard(3);g + let _twiddles = get_twiddles(domain.clone()); + // println!("DOMAIN: {:?}", Coset::get_coset_points(&domain)); + // println!("----------------------"); + // println!("TWIDDLES: {:?}", twiddles); + + assert_eq!(FieldElement::::from(&32768), + FieldElement::::from(&590768354).square().double() - FieldElement::::one() + ); + assert_eq!(-FieldElement::::from(&32768), + FieldElement::::from(&978592373).square().double() - FieldElement::::one() + ) + } +} \ No newline at end of file From a85eb0ffef84060207b2cbd655e3019e15f0997b Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 2 Oct 2024 13:08:57 -0300 Subject: [PATCH 32/93] test cfft --- math/src/circle/cfft.rs | 92 +++++++++++++++++++++++++++++++------ math/src/circle/twiddles.rs | 17 +++---- 2 files changed, 86 insertions(+), 23 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 413fe83b9..0433a2362 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -1,5 +1,5 @@ -use crate::circle::{cosets::Coset, point::CirclePoint}; -use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use super::{cosets::Coset, point::CirclePoint, twiddles::get_twiddles}; +use crate::{fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}}; pub fn cfft(input: &mut [FieldElement], twiddles: Vec>>) @@ -7,6 +7,7 @@ pub fn cfft(input: &mut [FieldElement], twiddles: Vec], twiddles: Vec], twiddles: Vec; + + pub fn reverse_bits_len(x: usize, bit_len: usize) -> usize { + // NB: The only reason we need overflowing_shr() here as opposed + // to plain '>>' is to accommodate the case n == num_bits == 0, + // which would become `0 >> 64`. Rust thinks that any shift of 64 + // bits causes overflow, even when the argument is zero. + x.reverse_bits() + .overflowing_shr(usize::BITS - bit_len as u32) + .0 + } + + + fn cfft_permute_index(index: usize, log_n: usize) -> usize { + let (index, lsb) = (index >> 1, index & 1); + reverse_bits_len( + if lsb == 0 { + index + } else { + (1 << log_n) - index - 1 + }, + log_n, + ) + } + pub(crate) fn cfft_permute_slice(xs: &[T], log_2_size: usize) -> Vec { + (0..xs.len()) + .map(|i| xs[cfft_permute_index(i, log_2_size)].clone()) + .collect() + } + + fn evaluate_poly(coef: &[FpE;8], x: FpE, y: FpE) -> FpE { + coef[0] + + coef[1] * y + + coef[2] * x + + coef[3] * x * y + + coef[4] * (x.square().double() - FpE::one()) + + coef[5] * (x.square().double() - FpE::one()) * y + + coef[6] * ((x.square() * x).double() - x ) + + coef[7] * ((x.square() * x).double() - x ) * y + } + + #[test] + fn cfft_test() { + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + let twiddles = get_twiddles(coset); + let mut input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4), FpE::from(5), FpE::from(6), FpE::from(7), FpE::from(8)]; + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly(&input, point.x, point.y); + expected_result.push(point_eval); + } + cfft(&mut input, twiddles); + let ordered_cfft_result = cfft_permute_slice(&mut input, 3); + assert_eq!(ordered_cfft_result, expected_result); + + } +} -// #[test] -// fn cfft() { -// let coset = Coset::new_standard(3); -// let twiddles = -// assert_eq!(1 << coset.log_2_size, points.len()) -// } -// } +/* +(1, y, => x, xy, +2xˆ2 - 1, 2xˆ2 y - y, 2xˆ3 - x, 2xˆ3 y - x y) +*/ \ No newline at end of file diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 69f04bbd2..5daddc203 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -20,6 +20,7 @@ pub fn get_twiddles(domain: Coset) -> Vec>> { twiddles.push(cur); } } + twiddles.reverse(); twiddles } @@ -38,17 +39,17 @@ mod tests { #[test] fn twiddles_test() { - let domain = Coset::new_standard(3);g - let _twiddles = get_twiddles(domain.clone()); + let domain = Coset::new_standard(3); + // let twiddles = get_twiddles(domain.clone()); // println!("DOMAIN: {:?}", Coset::get_coset_points(&domain)); // println!("----------------------"); // println!("TWIDDLES: {:?}", twiddles); - assert_eq!(FieldElement::::from(&32768), - FieldElement::::from(&590768354).square().double() - FieldElement::::one() - ); - assert_eq!(-FieldElement::::from(&32768), - FieldElement::::from(&978592373).square().double() - FieldElement::::one() - ) + // assert_eq!(FieldElement::::from(&32768), + // FieldElement::::from(&590768354).square().double() - FieldElement::::one() + // ); + // assert_eq!(-FieldElement::::from(&32768), + // FieldElement::::from(&978592373).square().double() - FieldElement::::one() + // ) } } \ No newline at end of file From 6bda70dce14b23c382b5d260ddcd72d3325f68e1 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 2 Oct 2024 15:23:37 -0300 Subject: [PATCH 33/93] test 16 not working --- math/src/circle/cfft.rs | 151 +++++++++++++++++++++++++++++++++------- 1 file changed, 126 insertions(+), 25 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 0433a2362..2e82f2b31 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -1,9 +1,13 @@ use super::{cosets::Coset, point::CirclePoint, twiddles::get_twiddles}; -use crate::{fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}}; - - -pub fn cfft(input: &mut [FieldElement], twiddles: Vec>>) -{ +use crate::{ + fft::cpu::bit_reversing::in_place_bit_reverse_permute, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; + +pub fn cfft( + input: &mut [FieldElement], + twiddles: Vec>>, +) { // divide input in groups, starting with 1, duplicating the number of groups in each stage. let mut group_count = 1; let mut group_size = input.len(); @@ -21,7 +25,7 @@ pub fn cfft(input: &mut [FieldElement], twiddles: Vec], twiddles: Vec usize { let (index, lsb) = (index >> 1, index & 1); reverse_bits_len( @@ -78,15 +80,50 @@ mod tests { .collect() } - fn evaluate_poly(coef: &[FpE;8], x: FpE, y: FpE) -> FpE { - coef[0] + - coef[1] * y + - coef[2] * x + - coef[3] * x * y + - coef[4] * (x.square().double() - FpE::one()) + - coef[5] * (x.square().double() - FpE::one()) * y + - coef[6] * ((x.square() * x).double() - x ) + - coef[7] * ((x.square() * x).double() - x ) * y + fn evaluate_poly(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { + coef[0] + + coef[1] * y + + coef[2] * x + + coef[3] * x * y + + coef[4] * (x.square().double() - FpE::one()) + + coef[5] * (x.square().double() - FpE::one()) * y + + coef[6] * ((x.square() * x).double() - x) + + coef[7] * ((x.square() * x).double() - x) * y + } + + fn evaluate_poly_16(coef: &[FpE; 16], x: FpE, y: FpE) -> FpE { + // v0 = 1 + // v1 = x + // v2 = 2x^2 - 1 + // v3 = 2(x^2 - 1)^2 - 1 + // v4 = 2((x^2 - 1)^2 - 1)^2 - 1 + let mut a = x; + let mut v = Vec::new(); + v.push(FpE::one()); + v.push(x); + for _ in 2..4 { + a = a.square() - FpE::one(); + v.push(a.double() + FpE::one()); + } + // println!("{:?}", coef[7] * y * v[1] * v[2]); + // println!("-------------------"); + // println!("{:?}", coef[7] * ((x.square() * x).double() - x) * y); + coef[0] * v[0] + + coef[1] * y * v[0] + + coef[2] * v[1] + + coef[3] * y * v[1] + + coef[4] * v[2] + + coef[5] * y * v[2] + + coef[6] * v[1] * v[2] + + coef[7] * y * v[1] * v[2] + + coef[8] * v[3] + + coef[9] * y * v[3] + + coef[10] * v[1] * v[3] + + coef[11] * y * v[1] * v[3] + + coef[12] * v[2] * v[3] + + coef[13] * y * v[2] * v[3] + + coef[14] * v[1] * v[2] * v[3] + + coef[15] * y * v[1] * v[2] * v[3] } #[test] @@ -94,20 +131,84 @@ mod tests { let coset = Coset::new_standard(3); let points = Coset::get_coset_points(&coset); let twiddles = get_twiddles(coset); - let mut input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4), FpE::from(5), FpE::from(6), FpE::from(7), FpE::from(8)]; + let mut input = [ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + ]; let mut expected_result: Vec = Vec::new(); for point in points { - let point_eval = evaluate_poly(&input, point.x, point.y); - expected_result.push(point_eval); + let point_eval = evaluate_poly(&input, point.x, point.y); + expected_result.push(point_eval); } cfft(&mut input, twiddles); let ordered_cfft_result = cfft_permute_slice(&mut input, 3); assert_eq!(ordered_cfft_result, expected_result); - + } + + #[test] + fn cfft_test_16() { + let coset = Coset::new_standard(4); + let points = Coset::get_coset_points(&coset); + let twiddles = get_twiddles(coset); + let mut input = [ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + FpE::from(9), + FpE::from(10), + FpE::from(11), + FpE::from(12), + FpE::from(13), + FpE::from(14), + FpE::from(15), + FpE::from(16), + ]; + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly_16(&input, point.x, point.y); + expected_result.push(point_eval); + } + cfft(&mut input, twiddles); + let ordered_cfft_result = cfft_permute_slice(&mut input, 4); + assert_eq!(ordered_cfft_result, expected_result); + } + + #[test] + fn print() { + let mut input = [ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + FpE::from(9), + FpE::from(10), + FpE::from(11), + FpE::from(12), + FpE::from(13), + FpE::from(14), + FpE::from(15), + FpE::from(16), + ]; + evaluate_poly_16(&input, FpE::from(20), FpE::from(33)); } } /* -(1, y, => x, xy, -2xˆ2 - 1, 2xˆ2 y - y, 2xˆ3 - x, 2xˆ3 y - x y) -*/ \ No newline at end of file +(1, y, => x, xy, +2xˆ2 - 1, 2xˆ2 y - y, 2xˆ3 - x, 2xˆ3 y - x y) +*/ From 7cbd17397592430ca05f1ee99d21260c599f8493 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 2 Oct 2024 15:53:54 -0300 Subject: [PATCH 34/93] fix n16 test --- math/src/circle/cfft.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 2e82f2b31..9e29faabb 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -102,8 +102,8 @@ mod tests { v.push(FpE::one()); v.push(x); for _ in 2..4 { - a = a.square() - FpE::one(); - v.push(a.double() + FpE::one()); + a = a.square().double() - FpE::one(); + v.push(a); } // println!("{:?}", coef[7] * y * v[1] * v[2]); // println!("-------------------"); From 4f60a488aadb88547388ee9957040a6b934c778e Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 3 Oct 2024 13:09:57 -0300 Subject: [PATCH 35/93] refactor --- math/src/circle/cfft.rs | 115 ++++++++++------------------------------ 1 file changed, 29 insertions(+), 86 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 9e29faabb..3b5a2263b 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -1,23 +1,13 @@ -use super::{cosets::Coset, point::CirclePoint, twiddles::get_twiddles}; -use crate::{ - fft::cpu::bit_reversing::in_place_bit_reverse_permute, - field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, -}; +use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; -pub fn cfft( +pub fn inplace_cfft( input: &mut [FieldElement], twiddles: Vec>>, ) { - // divide input in groups, starting with 1, duplicating the number of groups in each stage. let mut group_count = 1; let mut group_size = input.len(); let mut round = 0; - // for each group, there'll be group_size / 2 butterflies. - // a butterfly is the atomic operation of a FFT, e.g: (a, b) = (a + wb, a - wb). - // The 0.5 factor is what gives FFT its performance, it recursively halves the problem size - // (group size). - while group_count < input.len() { let round_twiddles = &twiddles[round]; #[allow(clippy::needless_range_loop)] // the suggestion would obfuscate a bit the algorithm @@ -36,9 +26,6 @@ pub fn cfft( input[i] = y0; input[i + group_size / 2] = y1; } - - // input = [input_0 + y_0 * input_1, input_0 - y_0 * input_1, ] - // p(x_0, y_0) p(x_7, y_7) } group_count *= 2; group_size /= 2; @@ -46,40 +33,29 @@ pub fn cfft( } } +pub fn inplace_order_cfft_values(input: &mut [FieldElement]) { + for i in 0..input.len() { + let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros() as u32); + if cfft_index > i { + input.swap(i, cfft_index); + } + } +} + +pub fn reverse_cfft_index(index: usize, log_2_size: u32) -> usize { + let (mut new_index, lsb) = (index >> 1, index & 1); + if (lsb == 1) & (log_2_size > 1) { + new_index = (1 << log_2_size) - new_index - 1; + } + new_index.reverse_bits() >> (usize::BITS - log_2_size) +} + #[cfg(test)] mod tests { - use crate::circle::twiddles; - use super::*; + use crate::circle::{cosets::Coset, twiddles::get_twiddles}; type FpE = FieldElement; - pub fn reverse_bits_len(x: usize, bit_len: usize) -> usize { - // NB: The only reason we need overflowing_shr() here as opposed - // to plain '>>' is to accommodate the case n == num_bits == 0, - // which would become `0 >> 64`. Rust thinks that any shift of 64 - // bits causes overflow, even when the argument is zero. - x.reverse_bits() - .overflowing_shr(usize::BITS - bit_len as u32) - .0 - } - - fn cfft_permute_index(index: usize, log_n: usize) -> usize { - let (index, lsb) = (index >> 1, index & 1); - reverse_bits_len( - if lsb == 0 { - index - } else { - (1 << log_n) - index - 1 - }, - log_n, - ) - } - pub(crate) fn cfft_permute_slice(xs: &[T], log_2_size: usize) -> Vec { - (0..xs.len()) - .map(|i| xs[cfft_permute_index(i, log_2_size)].clone()) - .collect() - } - fn evaluate_poly(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { coef[0] + coef[1] * y @@ -92,11 +68,6 @@ mod tests { } fn evaluate_poly_16(coef: &[FpE; 16], x: FpE, y: FpE) -> FpE { - // v0 = 1 - // v1 = x - // v2 = 2x^2 - 1 - // v3 = 2(x^2 - 1)^2 - 1 - // v4 = 2((x^2 - 1)^2 - 1)^2 - 1 let mut a = x; let mut v = Vec::new(); v.push(FpE::one()); @@ -105,9 +76,7 @@ mod tests { a = a.square().double() - FpE::one(); v.push(a); } - // println!("{:?}", coef[7] * y * v[1] * v[2]); - // println!("-------------------"); - // println!("{:?}", coef[7] * ((x.square() * x).double() - x) * y); + coef[0] * v[0] + coef[1] * y * v[0] + coef[2] * v[1] @@ -146,9 +115,10 @@ mod tests { let point_eval = evaluate_poly(&input, point.x, point.y); expected_result.push(point_eval); } - cfft(&mut input, twiddles); - let ordered_cfft_result = cfft_permute_slice(&mut input, 3); - assert_eq!(ordered_cfft_result, expected_result); + inplace_cfft(&mut input, twiddles); + inplace_order_cfft_values(&mut input); + let result: &[FpE] = &input; + assert_eq!(result, expected_result); } #[test] @@ -179,36 +149,9 @@ mod tests { let point_eval = evaluate_poly_16(&input, point.x, point.y); expected_result.push(point_eval); } - cfft(&mut input, twiddles); - let ordered_cfft_result = cfft_permute_slice(&mut input, 4); - assert_eq!(ordered_cfft_result, expected_result); - } - - #[test] - fn print() { - let mut input = [ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), - FpE::from(9), - FpE::from(10), - FpE::from(11), - FpE::from(12), - FpE::from(13), - FpE::from(14), - FpE::from(15), - FpE::from(16), - ]; - evaluate_poly_16(&input, FpE::from(20), FpE::from(33)); + inplace_cfft(&mut input, twiddles); + inplace_order_cfft_values(&mut input); + let result: &[FpE] = &input; + assert_eq!(result, expected_result); } } - -/* -(1, y, => x, xy, -2xˆ2 - 1, 2xˆ2 y - y, 2xˆ3 - x, 2xˆ3 y - x y) -*/ From 95e87d0b87566b505b72df5da99fce656f89eb0e Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 4 Oct 2024 11:57:18 -0300 Subject: [PATCH 36/93] clippy --- math/benches/fields/mersenne31.rs | 4 - math/src/circle/cfft.rs | 8 +- math/src/circle/cosets.rs | 39 +++++---- math/src/circle/point.rs | 85 ++++++++++--------- math/src/circle/twiddles.rs | 40 +++------ .../src/field/fields/mersenne31/extensions.rs | 10 ++- 6 files changed, 92 insertions(+), 94 deletions(-) diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs index e1badefd8..e8d99d1c2 100644 --- a/math/benches/fields/mersenne31.rs +++ b/math/benches/fields/mersenne31.rs @@ -86,7 +86,6 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { .collect::>(); let mut group = c.benchmark_group("Mersenne31 operations"); - /* for i in input.clone().into_iter() { group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { bench.iter(|| { @@ -170,7 +169,6 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { }); }); } - */ for i in input.clone().into_iter() { group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { @@ -192,7 +190,6 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { }); } - /* for i in input.clone().into_iter() { group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { bench.iter(|| { @@ -257,5 +254,4 @@ pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { }); }); } - */ } diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 3b5a2263b..aae8ee1b1 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -18,10 +18,10 @@ pub fn inplace_cfft( let w = &round_twiddles[group]; // a twiddle factor is used per group for i in first_in_group..first_in_next_group { - let wi = w * &input[i + group_size / 2]; + let wi = w * input[i + group_size / 2]; - let y0 = &input[i] + &wi; - let y1 = &input[i] - &wi; + let y0 = input[i] + wi; + let y1 = input[i] - wi; input[i] = y0; input[i + group_size / 2] = y1; @@ -35,7 +35,7 @@ pub fn inplace_cfft( pub fn inplace_order_cfft_values(input: &mut [FieldElement]) { for i in 0..input.len() { - let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros() as u32); + let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros()); if cfft_index > i { input.swap(i, cfft_index); } diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index d01b449b9..52b9d67b6 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -2,7 +2,6 @@ use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; use std::iter::successors; - #[derive(Debug, Clone)] pub struct Coset { // Coset: shift + where n = 2^{log_2_size}. @@ -13,41 +12,49 @@ pub struct Coset { impl Coset { pub fn new(log_2_size: u32, shift: CirclePoint) -> Self { - Coset{ log_2_size, shift } + Coset { log_2_size, shift } } /// Returns the coset g_2n + pub fn new_standard(log_2_size: u32) -> Self { // shift is a generator of the subgroup of order 2n = 2^{log_2_size + 1}. - let shift = CirclePoint::get_generator_of_subgroup((log_2_size as u32) + 1); - Coset{ log_2_size, shift } + let shift = CirclePoint::get_generator_of_subgroup(log_2_size + 1); + Coset { log_2_size, shift } } - + /// Returns g_n, the generator of the subgroup of order n = 2^log_2_size. pub fn get_generator(&self) -> CirclePoint { - CirclePoint::generator().repeated_double(31 - self.log_2_size as u32) + CirclePoint::generator().repeated_double(31 - self.log_2_size) } - /// Given a standard coset g_2n + , returns the subcoset with half size g_2n + + /// Given a standard coset g_2n + , returns the subcoset with half size g_2n + pub fn half_coset(coset: Self) -> Self { - Coset { log_2_size: coset.log_2_size - 1, shift: coset.shift } - } + Coset { + log_2_size: coset.log_2_size - 1, + shift: coset.shift, + } + } /// Given a coset shift + G returns the coset -shift + G. /// Note that (g_2n + ) U (-g_2n + ) = g_2n + . pub fn conjugate(coset: Self) -> Self { - Coset { log_2_size: coset.log_2_size, shift: coset.shift.conjugate() } + Coset { + log_2_size: coset.log_2_size, + shift: coset.shift.conjugate(), + } } /// Returns the vector of shift + g for every g in . /// where g = i * g_n for i = 0, ..., n-1. pub fn get_coset_points(coset: &Self) -> Vec> { - // g_n the generator of the subgroup of order n. + // g_n the generator of the subgroup of order n. let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); let size: u8 = 1 << coset.log_2_size; - successors(Some(coset.shift.clone()), move |prev| Some(prev.clone() + generator_n.clone())) - .take(size.into()) - .collect() + successors(Some(coset.shift.clone()), move |prev| { + Some(prev.clone() + generator_n.clone()) + }) + .take(size.into()) + .collect() } } @@ -70,11 +77,11 @@ mod tests { let anitpode_point = points[6].clone(); assert_eq!(anitpode_point, point.antipode()) } - + #[test] fn coset_generator_has_right_order() { let coset = Coset::new(2, CirclePoint::generator().mul(3)); - let generator_n = coset.get_generator(); + let generator_n = coset.get_generator(); assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); } } diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index c39ceb2ec..ca01c443c 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -1,10 +1,12 @@ -use crate::field::traits::IsField; -use crate::field::{element::FieldElement, fields::mersenne31::{field::Mersenne31Field, extension::Degree4ExtensionField}}; use super::errors::CircleError; +use crate::field::traits::IsField; +use crate::field::{ + element::FieldElement, + fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, +}; use std::cmp::PartialEq; -use std::ops::Add; -use std::process::Output; use std::fmt::Debug; +use std::ops::Add; #[derive(Debug, Clone)] pub struct CirclePoint { @@ -19,27 +21,26 @@ pub trait HasCircleParams { const ORDER: u128; } - impl HasCircleParams for Mersenne31Field { type FE = FieldElement; // This could be a constant instead of a function - fn circle_generator() -> (Self::FE, Self::FE){ - ( - Self::FE::from(&2), - Self::FE::from(&1268011823) - ) + fn circle_generator() -> (Self::FE, Self::FE) { + (Self::FE::from(&2), Self::FE::from(&1268011823)) } - + /// ORDER = 2^31 const ORDER: u128 = 2147483648; } impl HasCircleParams for Degree4ExtensionField { type FE = FieldElement; - + // This could be a constant instead of a function - fn circle_generator() -> (FieldElement, FieldElement){ + fn circle_generator() -> ( + FieldElement, + FieldElement, + ) { ( Degree4ExtensionField::from_coeffcients( FieldElement::::one(), @@ -47,14 +48,12 @@ impl HasCircleParams for Degree4ExtensionField { FieldElement::::from(&478637715), FieldElement::::from(&513582971), ), - Degree4ExtensionField::from_coeffcients( FieldElement::::from(992285211), FieldElement::::from(649143431), FieldElement::::from(&740191619), - FieldElement::::from(&1186584352) - - ) + FieldElement::::from(&1186584352), + ), ) } @@ -62,7 +61,7 @@ impl HasCircleParams for Degree4ExtensionField { const ORDER: u128 = 21267647892944572736998860269687930880; } -impl> CirclePoint{ +impl> CirclePoint { pub fn new(x: FieldElement, y: FieldElement) -> Result { if x.square() + y.square() == FieldElement::one() { Ok(CirclePoint { x, y }) @@ -71,7 +70,7 @@ impl> CirclePoint{ } } - /// Neutral element of the Circle group (with additive notation). + /// Neutral element of the Circle group (with additive notation). pub fn zero() -> Self { Self::new(FieldElement::one(), FieldElement::zero()).unwrap() } @@ -80,9 +79,8 @@ impl> CirclePoint{ pub fn add(a: Self, b: Self) -> Self { let x = &a.x * &b.x - &a.y * &b.y; let y = a.x * b.y + a.y * b.x; - CirclePoint{ x, y } + CirclePoint { x, y } } - /// Computes n * (x, y) = (x ,y) + ... + (x, y) n-times. pub fn mul(self, mut scalar: u128) -> Self { @@ -90,7 +88,7 @@ impl> CirclePoint{ let mut cur = self; loop { if scalar == 0 { - return res + return res; } if scalar & 1 == 1 { res = res + cur.clone(); @@ -99,13 +97,14 @@ impl> CirclePoint{ scalar >>= 1; } } - + /// Computes 2(x, y) = (2x^2 - 1, 2xy). pub fn double(self) -> Self { Self::new( self.x.square().double() - FieldElement::one(), self.x.double() * self.y, - ).unwrap() + ) + .unwrap() } /// Computes 2^n * (x, y). @@ -133,18 +132,14 @@ impl> CirclePoint{ } } - pub fn eq(a: Self, b: Self) -> bool { a.x == b.x && a.y == b.y } pub fn generator() -> Self { - CirclePoint::new( - F::circle_generator().0, - F::circle_generator().1 - ).unwrap() + CirclePoint::new(F::circle_generator().0, F::circle_generator().1).unwrap() } - + /// Returns the generator of the subgroup of order n = 2^log_2_size. /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { @@ -183,38 +178,44 @@ mod tests { #[test] fn create_new_valid_g_point() { let valid_point = G::new(FE::one(), FE::zero()).unwrap(); - let expected = G { x: FE::one(), y: FE::zero() }; + let expected = G { + x: FE::one(), + y: FE::zero(), + }; assert_eq!(valid_point, expected) } #[test] fn create_new_valid_g4_point() { let valid_point = G4::new(Fp4E::one(), Fp4E::zero()).unwrap(); - let expected = G4 { x: Fp4E::one(), y: Fp4E::zero() }; + let expected = G4 { + x: Fp4E::one(), + y: Fp4E::zero(), + }; assert_eq!(valid_point, expected) } #[test] fn create_new_invalid_circle_point() { - let invalid_point = G::new(FE::one(), FE::one()); + let invalid_point = G::new(FE::one(), FE::one()); assert!(invalid_point.is_err()) } #[test] fn create_new_invalid_g4_circle_point() { - let invalid_point = G4::new(Fp4E::one(), Fp4E::one()); + let invalid_point = G4::new(Fp4E::one(), Fp4E::one()); assert!(invalid_point.is_err()) } #[test] fn zero_plus_zero_is_zero() { let a = G::zero(); - let b = G::zero(); + let b = G::zero(); assert_eq!(a + b, G::zero()) } #[test] - fn generator_plus_zero_is_generator(){ + fn generator_plus_zero_is_generator() { let g = G::generator(); let zero = G::zero(); assert_eq!(g.clone() + zero, g) @@ -227,32 +228,32 @@ mod tests { } #[test] - fn mul_eight_equals_double_three_times(){ + fn mul_eight_equals_double_three_times() { let g = G::generator(); assert_eq!(g.clone().repeated_double(3), G::mul(g, 8)) } #[test] - fn generator_g1_has_order_two_pow_31 (){ + fn generator_g1_has_order_two_pow_31() { let g = G::generator(); let n = 31; assert_eq!(g.repeated_double(n), G::zero()) } #[test] - fn generator_g4_has_the_order_of_the_group (){ + fn generator_g4_has_the_order_of_the_group() { let g = G4::generator(); assert_eq!(g.mul(G4::group_order()), G4::zero()) } #[test] - fn conjugation_is_inverse_operation () { + fn conjugation_is_inverse_operation() { let g = G::generator(); - assert_eq!(g.clone() + g.conjugate() , G::zero()) + assert_eq!(g.clone() + g.conjugate(), G::zero()) } #[test] - fn subgroup_generator_has_correct_order(){ + fn subgroup_generator_has_correct_order() { let generator_n = G::get_generator_of_subgroup(7); assert_eq!(generator_n.repeated_double(7), G::zero()); } diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 5daddc203..d72a02025 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -1,11 +1,15 @@ use super::{cosets::Coset, point::CirclePoint}; -use crate::{fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}}; +use crate::{ + fft::cpu::bit_reversing::in_place_bit_reverse_permute, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; pub fn get_twiddles(domain: Coset) -> Vec>> { - let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); - in_place_bit_reverse_permute::>(&mut half_domain_points[..]); + let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); + in_place_bit_reverse_permute::>(&mut half_domain_points[..]); - let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; + let mut twiddles: Vec>> = + vec![half_domain_points.iter().map(|p| p.y).collect()]; if domain.log_2_size >= 2 { twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect()); @@ -28,28 +32,12 @@ pub fn get_twiddles(domain: Coset) -> Vec>> { mod tests { use super::*; - // #[test] - // fn twiddles_vectors_lenght() { - // let domain = Coset::new_standard(3); - // let twiddles = get_twiddles(domain); - // for i in 0..twiddles.len() - 1 { - // assert_eq!(twiddles[i].len(), 2 * twiddles[i+1].len()) - // } - // } - #[test] - fn twiddles_test() { + fn twiddles_vectors_lenght() { let domain = Coset::new_standard(3); - // let twiddles = get_twiddles(domain.clone()); - // println!("DOMAIN: {:?}", Coset::get_coset_points(&domain)); - // println!("----------------------"); - // println!("TWIDDLES: {:?}", twiddles); - - // assert_eq!(FieldElement::::from(&32768), - // FieldElement::::from(&590768354).square().double() - FieldElement::::one() - // ); - // assert_eq!(-FieldElement::::from(&32768), - // FieldElement::::from(&978592373).square().double() - FieldElement::::one() - // ) + let twiddles = get_twiddles(domain); + for i in 0..twiddles.len() - 1 { + assert_eq!(twiddles[i].len(), 2 * twiddles[i + 1].len()) + } } -} \ No newline at end of file +} diff --git a/math/src/field/fields/mersenne31/extensions.rs b/math/src/field/fields/mersenne31/extensions.rs index 27c2ab118..2ec853ec0 100644 --- a/math/src/field/fields/mersenne31/extensions.rs +++ b/math/src/field/fields/mersenne31/extensions.rs @@ -8,6 +8,8 @@ use crate::field::{ use alloc::vec::Vec; type FpE = FieldElement; +type Fp2E = FieldElement; +type Fp4E = FieldElement; #[derive(Clone, Debug)] pub struct Degree2ExtensionField; @@ -132,11 +134,15 @@ impl IsSubFieldOf for Mersenne31Field { } } -type Fp2E = FieldElement; - #[derive(Clone, Debug)] pub struct Degree4ExtensionField; +impl Degree4ExtensionField { + pub fn from_coeffcients(a: FpE, b: FpE, c: FpE, d: FpE) -> Fp4E { + Fp4E::new([Fp2E::new([a, b]), Fp2E::new([c, d])]) + } +} + impl IsField for Degree4ExtensionField { type BaseType = [Fp2E; 2]; From 10636f9c6a41785dcec4fd8d83661c99ce7825a2 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 4 Oct 2024 11:58:24 -0300 Subject: [PATCH 37/93] fmt --- math/src/circle/errors.rs | 1 - math/src/circle/mod.rs | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/math/src/circle/errors.rs b/math/src/circle/errors.rs index d2f569d19..07b63ec70 100644 --- a/math/src/circle/errors.rs +++ b/math/src/circle/errors.rs @@ -1,4 +1,3 @@ - #[derive(Debug)] pub enum CircleError { InvalidValue, diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index f02e07dbd..4876f728a 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -1,5 +1,5 @@ -pub mod point; -pub mod errors; +pub mod cfft; pub mod cosets; +pub mod errors; +pub mod point; pub mod twiddles; -pub mod cfft; \ No newline at end of file From 876459d1dc0d81379d16f18b13ca7f701d5be3ad Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 4 Oct 2024 16:34:51 -0300 Subject: [PATCH 38/93] clippy --- math/src/circle/cosets.rs | 2 +- math/src/circle/point.rs | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 52b9d67b6..ffff65428 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -80,7 +80,7 @@ mod tests { #[test] fn coset_generator_has_right_order() { - let coset = Coset::new(2, CirclePoint::generator().mul(3)); + let coset = Coset::new(2, CirclePoint::generator().scalar_mul(3)); let generator_n = coset.get_generator(); assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); } diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index ca01c443c..b7340a018 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -76,6 +76,7 @@ impl> CirclePoint { } /// Computes (a0, a1) + (b0, b1) = (a0 * b0 - a1 * b1, a0 * b1 + a1 * b0) + #[allow(clippy::should_implement_trait)] pub fn add(a: Self, b: Self) -> Self { let x = &a.x * &b.x - &a.y * &b.y; let y = a.x * b.y + a.y * b.x; @@ -83,7 +84,7 @@ impl> CirclePoint { } /// Computes n * (x, y) = (x ,y) + ... + (x, y) n-times. - pub fn mul(self, mut scalar: u128) -> Self { + pub fn scalar_mul(self, mut scalar: u128) -> Self { let mut res = Self::zero(); let mut cur = self; loop { @@ -224,13 +225,13 @@ mod tests { #[test] fn double_equals_mul_two() { let g = G::generator(); - assert_eq!(g.clone().double(), G::mul(g, 2)) + assert_eq!(g.clone().double(), G::scalar_mul(g, 2)) } #[test] fn mul_eight_equals_double_three_times() { let g = G::generator(); - assert_eq!(g.clone().repeated_double(3), G::mul(g, 8)) + assert_eq!(g.clone().repeated_double(3), G::scalar_mul(g, 8)) } #[test] @@ -243,7 +244,7 @@ mod tests { #[test] fn generator_g4_has_the_order_of_the_group() { let g = G4::generator(); - assert_eq!(g.mul(G4::group_order()), G4::zero()) + assert_eq!(g.scalar_mul(G4::group_order()), G4::zero()) } #[test] From 076b683e6a12dd79af65cf3a88f0f30cf470687d Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 4 Oct 2024 17:02:56 -0300 Subject: [PATCH 39/93] rm std --- math/src/circle/cfft.rs | 1 + math/src/circle/cosets.rs | 1 + math/src/circle/point.rs | 4 +--- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index aae8ee1b1..58851c965 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -1,4 +1,5 @@ use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use alloc::vec::Vec; pub fn inplace_cfft( input: &mut [FieldElement], diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index ffff65428..4f87ce6b4 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -1,5 +1,6 @@ use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; +use alloc::vec::Vec; use std::iter::successors; #[derive(Debug, Clone)] diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index b7340a018..07617a66a 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -4,9 +4,7 @@ use crate::field::{ element::FieldElement, fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, }; -use std::cmp::PartialEq; -use std::fmt::Debug; -use std::ops::Add; +use core::ops::Add; #[derive(Debug, Clone)] pub struct CirclePoint { From 0d7d89f6b26cf70e252adf1b932250f94b659066 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 4 Oct 2024 17:51:36 -0300 Subject: [PATCH 40/93] add alloc --- math/src/circle/cfft.rs | 3 ++- math/src/circle/cosets.rs | 6 +++--- math/src/circle/twiddles.rs | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 58851c965..f6ef7efaf 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -1,6 +1,7 @@ +extern crate alloc; use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; -use alloc::vec::Vec; +#[cfg(feature = "alloc")] pub fn inplace_cfft( input: &mut [FieldElement], twiddles: Vec>>, diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 4f87ce6b4..0cbf1a5d7 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -1,7 +1,6 @@ +extern crate alloc; use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; -use alloc::vec::Vec; -use std::iter::successors; #[derive(Debug, Clone)] pub struct Coset { @@ -47,11 +46,12 @@ impl Coset { /// Returns the vector of shift + g for every g in . /// where g = i * g_n for i = 0, ..., n-1. + #[cfg(feature = "alloc")] pub fn get_coset_points(coset: &Self) -> Vec> { // g_n the generator of the subgroup of order n. let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); let size: u8 = 1 << coset.log_2_size; - successors(Some(coset.shift.clone()), move |prev| { + core::iter::successors(Some(coset.shift.clone()), move |prev| { Some(prev.clone() + generator_n.clone()) }) .take(size.into()) diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index d72a02025..2137ce316 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -1,9 +1,11 @@ -use super::{cosets::Coset, point::CirclePoint}; +extern crate alloc; use crate::{ + circle::{cosets::Coset, point::CirclePoint}, fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; +#[cfg(feature = "alloc")] pub fn get_twiddles(domain: Coset) -> Vec>> { let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); in_place_bit_reverse_permute::>(&mut half_domain_points[..]); From 9685a743469035962c162e96e8f6d621780e4175 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 4 Oct 2024 17:55:35 -0300 Subject: [PATCH 41/93] fix --- math/src/circle/cosets.rs | 1 + math/src/circle/twiddles.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 0cbf1a5d7..0da5f204b 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -1,6 +1,7 @@ extern crate alloc; use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; +use alloc::vec::Vec; #[derive(Debug, Clone)] pub struct Coset { diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 2137ce316..b6815c574 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -4,6 +4,7 @@ use crate::{ fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; +use alloc::vec::Vec; #[cfg(feature = "alloc")] pub fn get_twiddles(domain: Coset) -> Vec>> { From ac842bf9a6abadbed0ecad39382e48216f545ab8 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 7 Oct 2024 18:01:34 -0300 Subject: [PATCH 42/93] wip --- math/src/circle/cfft.rs | 108 +---------------------- math/src/circle/mod.rs | 1 + math/src/circle/point.rs | 1 + math/src/circle/polynomial.rs | 161 ++++++++++++++++++++++++++++++++++ math/src/circle/twiddles.rs | 20 +++-- 5 files changed, 180 insertions(+), 111 deletions(-) create mode 100644 math/src/circle/polynomial.rs diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index f6ef7efaf..cc3862bcd 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -6,6 +6,8 @@ pub fn inplace_cfft( input: &mut [FieldElement], twiddles: Vec>>, ) { + use super::twiddles::TwiddlesConfig; + let mut group_count = 1; let mut group_size = input.len(); let mut round = 0; @@ -51,109 +53,3 @@ pub fn reverse_cfft_index(index: usize, log_2_size: u32) -> usize { } new_index.reverse_bits() >> (usize::BITS - log_2_size) } - -#[cfg(test)] -mod tests { - use super::*; - use crate::circle::{cosets::Coset, twiddles::get_twiddles}; - type FpE = FieldElement; - - fn evaluate_poly(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { - coef[0] - + coef[1] * y - + coef[2] * x - + coef[3] * x * y - + coef[4] * (x.square().double() - FpE::one()) - + coef[5] * (x.square().double() - FpE::one()) * y - + coef[6] * ((x.square() * x).double() - x) - + coef[7] * ((x.square() * x).double() - x) * y - } - - fn evaluate_poly_16(coef: &[FpE; 16], x: FpE, y: FpE) -> FpE { - let mut a = x; - let mut v = Vec::new(); - v.push(FpE::one()); - v.push(x); - for _ in 2..4 { - a = a.square().double() - FpE::one(); - v.push(a); - } - - coef[0] * v[0] - + coef[1] * y * v[0] - + coef[2] * v[1] - + coef[3] * y * v[1] - + coef[4] * v[2] - + coef[5] * y * v[2] - + coef[6] * v[1] * v[2] - + coef[7] * y * v[1] * v[2] - + coef[8] * v[3] - + coef[9] * y * v[3] - + coef[10] * v[1] * v[3] - + coef[11] * y * v[1] * v[3] - + coef[12] * v[2] * v[3] - + coef[13] * y * v[2] * v[3] - + coef[14] * v[1] * v[2] * v[3] - + coef[15] * y * v[1] * v[2] * v[3] - } - - #[test] - fn cfft_test() { - let coset = Coset::new_standard(3); - let points = Coset::get_coset_points(&coset); - let twiddles = get_twiddles(coset); - let mut input = [ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), - ]; - let mut expected_result: Vec = Vec::new(); - for point in points { - let point_eval = evaluate_poly(&input, point.x, point.y); - expected_result.push(point_eval); - } - inplace_cfft(&mut input, twiddles); - inplace_order_cfft_values(&mut input); - let result: &[FpE] = &input; - assert_eq!(result, expected_result); - } - - #[test] - fn cfft_test_16() { - let coset = Coset::new_standard(4); - let points = Coset::get_coset_points(&coset); - let twiddles = get_twiddles(coset); - let mut input = [ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), - FpE::from(9), - FpE::from(10), - FpE::from(11), - FpE::from(12), - FpE::from(13), - FpE::from(14), - FpE::from(15), - FpE::from(16), - ]; - let mut expected_result: Vec = Vec::new(); - for point in points { - let point_eval = evaluate_poly_16(&input, point.x, point.y); - expected_result.push(point_eval); - } - inplace_cfft(&mut input, twiddles); - inplace_order_cfft_values(&mut input); - let result: &[FpE] = &input; - assert_eq!(result, expected_result); - } -} diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index 4876f728a..b76831d65 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -3,3 +3,4 @@ pub mod cosets; pub mod errors; pub mod point; pub mod twiddles; +pub mod polynomial; \ No newline at end of file diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 07617a66a..92b6200f5 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -117,6 +117,7 @@ impl> CirclePoint { /// Computes the inverse of the point. /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation. + /// because the norm of every point in the circle is one. pub fn conjugate(self) -> Self { Self { x: self.x, diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs new file mode 100644 index 000000000..2937897e4 --- /dev/null +++ b/math/src/circle/polynomial.rs @@ -0,0 +1,161 @@ +use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; + +use super::{cfft::{inplace_cfft, inplace_order_cfft_values}, cosets::Coset, twiddles::{get_twiddles, TwiddlesConfig}}; + +/// Given the 2^n coefficients of a two-variables polynomial in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} +/// returns the evaluation of the polynomianl on the points of the standard coset of size 2^n. +/// Note that coeff has to be a vector with length a power of two 2^n. +pub fn evaluate_cfft(mut coeff: Vec>) -> Vec>{ + let domain_log_2_size: u32 = coeff.len().trailing_zeros(); + let coset = Coset::new_standard(domain_log_2_size); + let config = TwiddlesConfig::Evaluation; + let twiddles = get_twiddles(coset, config); + + inplace_cfft(&mut coeff, twiddles); + inplace_order_cfft_values(&mut coeff); + coeff +} + +/// Interpolates the 2^n evaluations of a two-variables polynomial on the points of the standard coset of size 2^n. +/// As a result we obtain the coefficients of the polynomial in the basis: {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} +/// Note that eval has to be a vector of length a power of two 2^n. +pub fn interpolate_cfft(mut eval: Vec>) -> Vec>{ + let domain_log_2_size: u32 = eval.len().trailing_zeros(); + let coset = Coset::new_standard(domain_log_2_size); + let config = TwiddlesConfig::Interpolation; + let twiddles = get_twiddles(coset, config); + + inplace_cfft(&mut eval, twiddles); + inplace_order_cfft_values(&mut eval); + eval +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circle::cosets::Coset; + type FpE = FieldElement; + + fn evaluate_poly(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { + coef[0] + + coef[1] * y + + coef[2] * x + + coef[3] * x * y + + coef[4] * (x.square().double() - FpE::one()) + + coef[5] * (x.square().double() - FpE::one()) * y + + coef[6] * ((x.square() * x).double() - x) + + coef[7] * ((x.square() * x).double() - x) * y + } + + fn evaluate_poly_16(coef: &[FpE; 16], x: FpE, y: FpE) -> FpE { + let mut a = x; + let mut v = Vec::new(); + v.push(FpE::one()); + v.push(x); + for _ in 2..4 { + a = a.square().double() - FpE::one(); + v.push(a); + } + + coef[0] * v[0] + + coef[1] * y * v[0] + + coef[2] * v[1] + + coef[3] * y * v[1] + + coef[4] * v[2] + + coef[5] * y * v[2] + + coef[6] * v[1] * v[2] + + coef[7] * y * v[1] * v[2] + + coef[8] * v[3] + + coef[9] * y * v[3] + + coef[10] * v[1] * v[3] + + coef[11] * y * v[1] * v[3] + + coef[12] * v[2] * v[3] + + coef[13] * y * v[2] * v[3] + + coef[14] * v[1] * v[2] * v[3] + + coef[15] * y * v[1] * v[2] * v[3] + } + + #[test] + fn cfft_evaluation_8_points() { + // We create the coset points and evaluate them without the fft. + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + let mut input = [ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + ]; + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly(&input, point.x, point.y); + expected_result.push(point_eval); + } + + let result = evaluate_cfft(input.to_vec()); + let slice_result: &[FpE] = &result; + assert_eq!(slice_result, expected_result); + } + + #[test] + fn cfft_evaluation_16_points() { + let coset = Coset::new_standard(4); + let points = Coset::get_coset_points(&coset); + let mut input = [ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + FpE::from(9), + FpE::from(10), + FpE::from(11), + FpE::from(12), + FpE::from(13), + FpE::from(14), + FpE::from(15), + FpE::from(16), + ]; + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly_16(&input, point.x, point.y); + expected_result.push(point_eval); + } + + let result = evaluate_cfft(input.to_vec()); + let slice_result: &[FpE] = &result; + assert_eq!(slice_result, expected_result); + } + + #[test] + fn evaluate_and_interpolate_8() { + // + let coeff = vec![ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + ]; + + + let evals = evaluate_cfft(coeff.clone()); + let factor = FpE::from(8).inv().unwrap(); + let mut new_coeff = interpolate_cfft(evals); + new_coeff = new_coeff.iter() + .map(|coeff| factor * coeff) + .collect(); + assert_eq!(new_coeff, coeff); + } +} + diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index b6815c574..bf0939bfb 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -6,17 +6,20 @@ use crate::{ }; use alloc::vec::Vec; +#[derive(PartialEq)] +pub enum TwiddlesConfig { + Evaluation, + Interpolation, +} #[cfg(feature = "alloc")] -pub fn get_twiddles(domain: Coset) -> Vec>> { +pub fn get_twiddles(domain: Coset, config: TwiddlesConfig) -> Vec>> { let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); in_place_bit_reverse_permute::>(&mut half_domain_points[..]); - let mut twiddles: Vec>> = - vec![half_domain_points.iter().map(|p| p.y).collect()]; + let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; if domain.log_2_size >= 2 { twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect()); - for _ in 0..(domain.log_2_size - 2) { let prev = twiddles.last().unwrap(); let cur = prev @@ -28,6 +31,12 @@ pub fn get_twiddles(domain: Coset) -> Vec>> { } } twiddles.reverse(); + + if config == TwiddlesConfig::Interpolation { + twiddles.iter_mut().for_each(|x| { + FieldElement::::inplace_batch_inverse(x).unwrap(); + }); + } twiddles } @@ -38,7 +47,8 @@ mod tests { #[test] fn twiddles_vectors_lenght() { let domain = Coset::new_standard(3); - let twiddles = get_twiddles(domain); + let config = TwiddlesConfig::Evaluation; + let twiddles = get_twiddles(domain, config); for i in 0..twiddles.len() - 1 { assert_eq!(twiddles[i].len(), 2 * twiddles[i + 1].len()) } From 5f97990186bcc0079a7872d440b6680336ca2916 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 14 Oct 2024 18:52:05 -0300 Subject: [PATCH 43/93] add hand iterpolation for 4 and 8 --- math/src/circle/cfft.rs | 71 ++++++++++++++++++++- math/src/circle/polynomial.rs | 113 ++++++++++++++++++++++++++++------ math/src/circle/twiddles.rs | 37 ++++++++++- 3 files changed, 198 insertions(+), 23 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index cc3862bcd..791a8d7ab 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -6,8 +6,6 @@ pub fn inplace_cfft( input: &mut [FieldElement], twiddles: Vec>>, ) { - use super::twiddles::TwiddlesConfig; - let mut group_count = 1; let mut group_size = input.len(); let mut round = 0; @@ -37,6 +35,75 @@ pub fn inplace_cfft( } } +pub fn cfft_4( + input: &mut [FieldElement], + twiddles: Vec>>, +) -> Vec> { + let mut stage1: Vec> = Vec::with_capacity(4); + + stage1.push(input[0] + input[1]); + stage1.push((input[0] - input[1]) * twiddles[0][0]); + + stage1.push(input[2] + input[3]); + stage1.push((input[2] - input[3]) * twiddles[0][1]); + + let mut stage2: Vec> = Vec::with_capacity(4); + + stage2.push(stage1[0] + stage1[2]); + stage2.push(stage1[1] + stage1[3]); + + stage2.push((stage1[0] - stage1[2]) * twiddles[1][0]); + stage2.push((stage1[1] - stage1[3]) * twiddles[1][0]); + + let f = FieldElement::::from(4).inv().unwrap(); + stage2.into_iter().map(|elem| elem * f).collect() +} + +pub fn cfft_8( + input: &mut [FieldElement], + twiddles: Vec>>, +) -> Vec> { + let mut stage1: Vec> = Vec::with_capacity(8); + + stage1.push(input[0] + input[4]); + stage1.push(input[1] + input[5]); + stage1.push(input[2] + input[6]); + stage1.push(input[3] + input[7]); + stage1.push((input[0] - input[4]) * twiddles[0][0]); + stage1.push((input[1] - input[5]) * twiddles[0][1]); + stage1.push((input[2] - input[6]) * twiddles[0][2]); + stage1.push((input[3] - input[7]) * twiddles[0][3]); + + let mut stage2: Vec> = Vec::with_capacity(8); + + stage2.push(stage1[0] + stage1[2]); + stage2.push(stage1[1] + stage1[3]); + stage2.push((stage1[0] - stage1[2]) * twiddles[1][0]); + stage2.push((stage1[1] - stage1[3]) * twiddles[1][1]); + + stage2.push(stage1[4] + stage1[6]); + stage2.push(stage1[5] + stage1[7]); + stage2.push((stage1[4] - stage1[6]) * twiddles[1][0]); + stage2.push((stage1[5] - stage1[7]) * twiddles[1][1]); + + let mut stage3: Vec> = Vec::with_capacity(8); + + stage3.push(stage2[0] + stage2[1]); + stage3.push((stage2[0] - stage2[1]) * twiddles[2][0]); + + stage3.push(stage2[2] + stage2[3]); + stage3.push((stage2[2] - stage2[3]) * twiddles[2][0]); + + stage3.push(stage2[4] + stage2[5]); + stage3.push((stage2[4] - stage2[5]) * twiddles[2][0]); + + stage3.push(stage2[6] + stage2[7]); + stage3.push((stage2[6] - stage2[7]) * twiddles[2][0]); + + let f = FieldElement::::from(8).inv().unwrap(); + stage3.into_iter().map(|elem| elem * f).collect() +} + pub fn inplace_order_cfft_values(input: &mut [FieldElement]) { for i in 0..input.len() { let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros()); diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 2937897e4..e7c2ebaf4 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,11 +1,19 @@ use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; -use super::{cfft::{inplace_cfft, inplace_order_cfft_values}, cosets::Coset, twiddles::{get_twiddles, TwiddlesConfig}}; +use super::{ + cfft::{cfft_4, cfft_8, inplace_cfft, inplace_order_cfft_values}, + cosets::Coset, + twiddles::{ + get_twiddles, get_twiddles_itnerpolation_4, get_twiddles_itnerpolation_8, TwiddlesConfig, + }, +}; /// Given the 2^n coefficients of a two-variables polynomial in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomianl on the points of the standard coset of size 2^n. /// Note that coeff has to be a vector with length a power of two 2^n. -pub fn evaluate_cfft(mut coeff: Vec>) -> Vec>{ +pub fn evaluate_cfft( + mut coeff: Vec>, +) -> Vec> { let domain_log_2_size: u32 = coeff.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Evaluation; @@ -19,24 +27,52 @@ pub fn evaluate_cfft(mut coeff: Vec>) -> Vec>) -> Vec>{ +pub fn interpolate_cfft( + mut eval: Vec>, +) -> Vec> { let domain_log_2_size: u32 = eval.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Interpolation; let twiddles = get_twiddles(coset, config); - + inplace_cfft(&mut eval, twiddles); inplace_order_cfft_values(&mut eval); eval } +pub fn interpolate_4( + mut eval: Vec>, +) -> Vec> { + let domain_log_2_size: u32 = eval.len().trailing_zeros(); + let coset = Coset::new_standard(domain_log_2_size); + let twiddles = get_twiddles_itnerpolation_4(coset); + + let res = cfft_4(&mut eval, twiddles); + res +} + +pub fn interpolate_8( + mut eval: Vec>, +) -> Vec> { + let domain_log_2_size: u32 = eval.len().trailing_zeros(); + let coset = Coset::new_standard(domain_log_2_size); + let twiddles = get_twiddles_itnerpolation_8(coset); + + let res = cfft_8(&mut eval, twiddles); + res +} + #[cfg(test)] mod tests { use super::*; use crate::circle::cosets::Coset; type FpE = FieldElement; - fn evaluate_poly(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { + fn evaluate_poly_4(coef: &[FpE; 4], x: FpE, y: FpE) -> FpE { + coef[0] + coef[1] * y + coef[2] * x + coef[3] * x * y + } + + fn evaluate_poly_8(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { coef[0] + coef[1] * y + coef[2] * x @@ -75,6 +111,23 @@ mod tests { + coef[15] * y * v[1] * v[2] * v[3] } + #[test] + fn cfft_evaluation_4_points() { + // We create the coset points and evaluate them without the fft. + let coset = Coset::new_standard(2); + let points = Coset::get_coset_points(&coset); + let mut input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)]; + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly_4(&input, point.x, point.y); + expected_result.push(point_eval); + } + + let result = evaluate_cfft(input.to_vec()); + let slice_result: &[FpE] = &result; + assert_eq!(slice_result, expected_result); + } + #[test] fn cfft_evaluation_8_points() { // We create the coset points and evaluate them without the fft. @@ -92,10 +145,10 @@ mod tests { ]; let mut expected_result: Vec = Vec::new(); for point in points { - let point_eval = evaluate_poly(&input, point.x, point.y); + let point_eval = evaluate_poly_8(&input, point.x, point.y); expected_result.push(point_eval); } - + let result = evaluate_cfft(input.to_vec()); let slice_result: &[FpE] = &result; assert_eq!(slice_result, expected_result); @@ -133,10 +186,9 @@ mod tests { let slice_result: &[FpE] = &result; assert_eq!(slice_result, expected_result); } - + #[test] - fn evaluate_and_interpolate_8() { - // + fn interpolation() { let coeff = vec![ FpE::from(1), FpE::from(2), @@ -148,14 +200,39 @@ mod tests { FpE::from(8), ]; - let evals = evaluate_cfft(coeff.clone()); - let factor = FpE::from(8).inv().unwrap(); - let mut new_coeff = interpolate_cfft(evals); - new_coeff = new_coeff.iter() - .map(|coeff| factor * coeff) - .collect(); - assert_eq!(new_coeff, coeff); + + // println!("EVALS: {:?}", evals); + + // EVALS: [ + // FieldElement { value: 885347334 }, -> 0 + // FieldElement { value: 1037382257 }, -> 1 + // FieldElement { value: 714723476 }, -> 2 + // FieldElement { value: 55636419 }, -> 3 + // FieldElement { value: 1262332919 }, -> 4 + // FieldElement { value: 1109642644 }, -> 5 + // FieldElement { value: 1432563561 }, -> 6 + // FieldElement { value: 2092305986 }] -> 7 + + let new_evals = vec![ + FpE::from(885347334), + FpE::from(714723476), + FpE::from(1262332919), + FpE::from(1432563561), + FpE::from(2092305986), + FpE::from(1109642644), + FpE::from(55636419), + FpE::from(1037382257), + ]; + + let new_coeff = interpolate_8(new_evals); + + println!("RES: {:?}", new_coeff); } -} + #[test] + fn cuentas() { + println!("{:?}", FpE::from(32768).inv().unwrap()); // { value: 65536 } + println!("{:?}", FpE::from(2147450879).inv().unwrap()); // { value: 2147418111 } + } +} diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index bf0939bfb..e3a892718 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -12,11 +12,17 @@ pub enum TwiddlesConfig { Interpolation, } #[cfg(feature = "alloc")] -pub fn get_twiddles(domain: Coset, config: TwiddlesConfig) -> Vec>> { +pub fn get_twiddles( + domain: Coset, + config: TwiddlesConfig, +) -> Vec>> { let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); - in_place_bit_reverse_permute::>(&mut half_domain_points[..]); + if config == TwiddlesConfig::Evaluation { + in_place_bit_reverse_permute::>(&mut half_domain_points[..]); + } - let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; + let mut twiddles: Vec>> = + vec![half_domain_points.iter().map(|p| p.y).collect()]; if domain.log_2_size >= 2 { twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect()); @@ -40,6 +46,31 @@ pub fn get_twiddles(domain: Coset, config: TwiddlesConfig) -> Vec Vec>> { + let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); + let mut twiddles: Vec>> = + vec![half_domain_points.iter().map(|p| p.y).collect()]; + twiddles.push(half_domain_points.iter().take(1).map(|p| p.x).collect()); + twiddles.iter_mut().for_each(|x| { + FieldElement::::inplace_batch_inverse(x).unwrap(); + }); + twiddles +} + +pub fn get_twiddles_itnerpolation_8(domain: Coset) -> Vec>> { + let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); + let mut twiddles: Vec>> = + vec![half_domain_points.iter().map(|p| p.y).collect()]; + twiddles.push(half_domain_points.iter().take(2).map(|p| p.x).collect()); + twiddles.push(vec![ + half_domain_points[0].x.square().double() - FieldElement::::one(), + ]); + twiddles.iter_mut().for_each(|x| { + FieldElement::::inplace_batch_inverse(x).unwrap(); + }); + twiddles +} + #[cfg(test)] mod tests { use super::*; From dc4124e96871c5c9a08e35b811a6f854f8424acd Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 15 Oct 2024 18:51:21 -0300 Subject: [PATCH 44/93] wip --- math/src/circle/cfft.rs | 222 +++++++++++++++++++++++++++++----- math/src/circle/polynomial.rs | 45 ++++--- math/src/circle/twiddles.rs | 14 +-- 3 files changed, 227 insertions(+), 54 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 791a8d7ab..1f680cbb8 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -2,39 +2,80 @@ extern crate alloc; use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; #[cfg(feature = "alloc")] -pub fn inplace_cfft( +pub fn cfft( input: &mut [FieldElement], twiddles: Vec>>, ) { - let mut group_count = 1; - let mut group_size = input.len(); - let mut round = 0; + let log_2_size = input.len().trailing_zeros(); + + (0..log_2_size).for_each(|i| { + let chunk_size = 1 << i + 1; + let half_chunk_size = 1 << i; + input.chunks_mut(chunk_size).for_each(|chunk| { + let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); + hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| { + let temp = *low * twiddles[i as usize][j]; + *low = *hi - temp; + *hi = *hi + temp; + }); + }); + }); +} - while group_count < input.len() { - let round_twiddles = &twiddles[round]; - #[allow(clippy::needless_range_loop)] // the suggestion would obfuscate a bit the algorithm - for group in 0..group_count { - let first_in_group = group * group_size; - let first_in_next_group = first_in_group + group_size / 2; - let w = &round_twiddles[group]; // a twiddle factor is used per group +#[cfg(feature = "alloc")] +pub fn icfft( + input: &mut [FieldElement], + twiddles: Vec>>, +) { + let log_2_size = input.len().trailing_zeros(); + + println!("{:?}", twiddles); + + (0..log_2_size).for_each(|i| { + let chunk_size = 1 << log_2_size - i; + let half_chunk_size = chunk_size >> 1; + input.chunks_mut(chunk_size).for_each(|chunk| { + let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); + hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| { + let temp = *hi + *low; + *low = (*hi - *low) * twiddles[i as usize][j]; + *hi = temp; + }); + }); + }); +} - for i in first_in_group..first_in_next_group { - let wi = w * input[i + group_size / 2]; +pub fn order_cfft_result_naive(input: &mut [FieldElement]) -> Vec> { + let mut result = Vec::new(); + let length = input.len(); + for i in (0..length/2) { + result.push(input[i]); + result.push(input[length - i - 1]); + } + result +} - let y0 = input[i] + wi; - let y1 = input[i] - wi; +pub fn order_icfft_input_naive(input: &mut [FieldElement]) -> Vec> { + let mut result = Vec::new(); + (0..input.len()).step_by(2).for_each( |i| { + result.push(input[i]); + }); + (1..input.len()).step_by(2).rev().for_each( |i| { + result.push(input[i]); + }); + result +} - input[i] = y0; - input[i + group_size / 2] = y1; - } - } - group_count *= 2; - group_size /= 2; - round += 1; +pub fn reverse_cfft_index(index: usize, length: usize) -> usize { + if index < (length >> 1) { // index < length / 2 + index << 1 // index * 2 + } else { + (((length - 1) - index) << 1) + 1 } } + pub fn cfft_4( input: &mut [FieldElement], twiddles: Vec>>, @@ -104,19 +145,134 @@ pub fn cfft_8( stage3.into_iter().map(|elem| elem * f).collect() } -pub fn inplace_order_cfft_values(input: &mut [FieldElement]) { - for i in 0..input.len() { - let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros()); - if cfft_index > i { - input.swap(i, cfft_index); + +#[cfg(test)] +mod tests { + use super::*; + type FE = FieldElement; + + #[test] + fn ordering_4() { + let expected_slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + ]; + + let mut slice = [ + FE::from(0), + FE::from(2), + FE::from(3), + FE::from(1), + ]; + + let res = order_cfft_result_naive(&mut slice); + + assert_eq!(res, expected_slice) + } + + #[test] + fn ordering() { + let expected_slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + ]; + + let mut slice = [ + FE::from(0), + FE::from(2), + FE::from(4), + FE::from(6), + FE::from(8), + FE::from(10), + FE::from(12), + FE::from(14), + FE::from(15), + FE::from(13), + FE::from(11), + FE::from(9), + FE::from(7), + FE::from(5), + FE::from(3), + FE::from(1), + ]; + + let res = order_cfft_result_naive(&mut slice); + + assert_eq!(res, expected_slice) + } + + #[test] + fn reverse_cfft_index_works() { + let mut reversed: Vec = Vec::with_capacity(16); + for i in 0..reversed.capacity() { + reversed.push(reverse_cfft_index(i, reversed.capacity())); } + assert_eq!( + reversed[..], + [0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1] + ); } -} -pub fn reverse_cfft_index(index: usize, log_2_size: u32) -> usize { - let (mut new_index, lsb) = (index >> 1, index & 1); - if (lsb == 1) & (log_2_size > 1) { - new_index = (1 << log_2_size) - new_index - 1; + #[test] + fn from_natural_to_icfft_input_order() { + let mut slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + ]; + + let expected_slice = [ + FE::from(0), + FE::from(2), + FE::from(4), + FE::from(6), + FE::from(8), + FE::from(10), + FE::from(12), + FE::from(14), + FE::from(15), + FE::from(13), + FE::from(11), + FE::from(9), + FE::from(7), + FE::from(5), + FE::from(3), + FE::from(1), + ]; + + let res = order_icfft_input_naive(&mut slice); + + assert_eq!(res, expected_slice) } - new_index.reverse_bits() >> (usize::BITS - log_2_size) + + } diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index e7c2ebaf4..b5f0e1877 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,7 +1,10 @@ -use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use crate::{ + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, + fft::cpu::bit_reversing::in_place_bit_reverse_permute +}; use super::{ - cfft::{cfft_4, cfft_8, inplace_cfft, inplace_order_cfft_values}, + cfft::{cfft, icfft, cfft_4, cfft_8, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, twiddles::{ get_twiddles, get_twiddles_itnerpolation_4, get_twiddles_itnerpolation_8, TwiddlesConfig, @@ -14,14 +17,15 @@ use super::{ pub fn evaluate_cfft( mut coeff: Vec>, ) -> Vec> { + in_place_bit_reverse_permute::>(&mut coeff); let domain_log_2_size: u32 = coeff.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Evaluation; let twiddles = get_twiddles(coset, config); - inplace_cfft(&mut coeff, twiddles); - inplace_order_cfft_values(&mut coeff); - coeff + cfft(&mut coeff, twiddles); + let result = order_cfft_result_naive(&mut coeff); + result } /// Interpolates the 2^n evaluations of a two-variables polynomial on the points of the standard coset of size 2^n. @@ -30,14 +34,15 @@ pub fn evaluate_cfft( pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { + let mut eval_ordered = order_icfft_input_naive(&mut eval); let domain_log_2_size: u32 = eval.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Interpolation; let twiddles = get_twiddles(coset, config); - inplace_cfft(&mut eval, twiddles); - inplace_order_cfft_values(&mut eval); - eval + icfft(&mut eval_ordered, twiddles); + let result = order_cfft_result_naive(&mut eval); + result } pub fn interpolate_4( @@ -116,7 +121,7 @@ mod tests { // We create the coset points and evaluate them without the fft. let coset = Coset::new_standard(2); let points = Coset::get_coset_points(&coset); - let mut input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)]; + let input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)]; let mut expected_result: Vec = Vec::new(); for point in points { let point_eval = evaluate_poly_4(&input, point.x, point.y); @@ -133,7 +138,7 @@ mod tests { // We create the coset points and evaluate them without the fft. let coset = Coset::new_standard(3); let points = Coset::get_coset_points(&coset); - let mut input = [ + let input = [ FpE::from(1), FpE::from(2), FpE::from(3), @@ -158,7 +163,7 @@ mod tests { fn cfft_evaluation_16_points() { let coset = Coset::new_standard(4); let points = Coset::get_coset_points(&coset); - let mut input = [ + let input = [ FpE::from(1), FpE::from(2), FpE::from(3), @@ -231,8 +236,20 @@ mod tests { } #[test] - fn cuentas() { - println!("{:?}", FpE::from(32768).inv().unwrap()); // { value: 65536 } - println!("{:?}", FpE::from(2147450879).inv().unwrap()); // { value: 2147418111 } + fn evaluate_and_interpolate() { + let coeff = vec![ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); + + assert_eq!(coeff, new_coeff); } } diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index e3a892718..abe48581b 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -16,32 +16,32 @@ pub fn get_twiddles( domain: Coset, config: TwiddlesConfig, ) -> Vec>> { - let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); - if config == TwiddlesConfig::Evaluation { - in_place_bit_reverse_permute::>(&mut half_domain_points[..]); - } + let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; if domain.log_2_size >= 2 { - twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect()); + twiddles.push(half_domain_points.iter().take(half_domain_points.len() / 2 ).map(|p| p.x).collect()); for _ in 0..(domain.log_2_size - 2) { let prev = twiddles.last().unwrap(); let cur = prev .iter() - .step_by(2) + .take(prev.len() / 2 ) .map(|x| x.square().double() - FieldElement::::one()) .collect(); twiddles.push(cur); } } - twiddles.reverse(); if config == TwiddlesConfig::Interpolation { + // For the interpolation, we need to take the inverse element of each twiddle in the default order. twiddles.iter_mut().for_each(|x| { FieldElement::::inplace_batch_inverse(x).unwrap(); }); + } else { + // For the evaluation, we need the vector of twiddles but in the inverse order. + twiddles.reverse(); } twiddles } From 56dae17740369069a9f1f714974a1e4188723ec1 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 16 Oct 2024 10:46:28 -0300 Subject: [PATCH 45/93] evaluation and interpolation working --- math/src/circle/cfft.rs | 2 ++ math/src/circle/polynomial.rs | 13 ++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 1f680cbb8..7865b1b67 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -46,6 +46,7 @@ pub fn icfft( }); } +// From [0, 2, 4, 6, 7, 5, 3, 1] to [0, 1, 2, 3, 4, 5, 6, 7] pub fn order_cfft_result_naive(input: &mut [FieldElement]) -> Vec> { let mut result = Vec::new(); let length = input.len(); @@ -56,6 +57,7 @@ pub fn order_cfft_result_naive(input: &mut [FieldElement]) -> V result } +// From [0, 1, 2, 3, 4, 5, 6, 7] to [0, 2, 4, 6, 7, 5, 3, 1] pub fn order_icfft_input_naive(input: &mut [FieldElement]) -> Vec> { let mut result = Vec::new(); (0..input.len()).step_by(2).for_each( |i| { diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index b5f0e1877..bf4da993c 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,10 +1,10 @@ use crate::{ + fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, - fft::cpu::bit_reversing::in_place_bit_reverse_permute }; use super::{ - cfft::{cfft, icfft, cfft_4, cfft_8, order_cfft_result_naive, order_icfft_input_naive}, + cfft::{cfft, cfft_4, cfft_8, icfft, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, twiddles::{ get_twiddles, get_twiddles_itnerpolation_4, get_twiddles_itnerpolation_8, TwiddlesConfig, @@ -34,15 +34,18 @@ pub fn evaluate_cfft( pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { - let mut eval_ordered = order_icfft_input_naive(&mut eval); + let mut eval_ordered = order_icfft_input_naive(&mut eval); let domain_log_2_size: u32 = eval.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Interpolation; let twiddles = get_twiddles(coset, config); icfft(&mut eval_ordered, twiddles); - let result = order_cfft_result_naive(&mut eval); - result + in_place_bit_reverse_permute::>(&mut eval_ordered); + let factor = (FieldElement::::from(eval.len() as u64)) + .inv() + .unwrap(); + eval_ordered.iter().map(|coef| coef * factor).collect() } pub fn interpolate_4( From 2546ca2d06a56ac1da2f463f9798e469976b3bf0 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 16 Oct 2024 13:35:11 -0300 Subject: [PATCH 46/93] add tests and comments --- math/src/circle/cfft.rs | 122 ++++++++++------- math/src/circle/polynomial.rs | 238 +++++++++++++++++++++++----------- math/src/circle/twiddles.rs | 28 +++- 3 files changed, 263 insertions(+), 125 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 7865b1b67..94eab38b7 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -2,82 +2,123 @@ extern crate alloc; use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; #[cfg(feature = "alloc")] +/// fft in place algorithm used to evaluate a polynomial of degree 2^n - 1 in 2^n points. +/// Input must be of size 2^n for some n. pub fn cfft( input: &mut [FieldElement], twiddles: Vec>>, ) { + // If the input size is 2^n, then log_2_size is n. let log_2_size = input.len().trailing_zeros(); - + + // The cfft has n layers. (0..log_2_size).for_each(|i| { + // In each layer i we split the current input in chunks of size 2^{i+1}. let chunk_size = 1 << i + 1; let half_chunk_size = 1 << i; input.chunks_mut(chunk_size).for_each(|chunk| { + // We split each chunk in half, calling the first half hi_part and the second hal low_part. let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); - hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| { - let temp = *low * twiddles[i as usize][j]; - *low = *hi - temp; - *hi = *hi + temp; - }); + + // We apply the corresponding butterfly for every element j of the high and low part. + hi_part + .into_iter() + .zip(low_part) + .enumerate() + .for_each(|(j, (hi, low))| { + let temp = *low * twiddles[i as usize][j]; + *low = *hi - temp; + *hi = *hi + temp; + }); }); }); -} - +} #[cfg(feature = "alloc")] +/// The inverse fft algorithm used to interpolate 2^n points. +/// Input must be of size 2^n for some n. pub fn icfft( input: &mut [FieldElement], twiddles: Vec>>, ) { + // If the input size is 2^n, then log_2_size is n. let log_2_size = input.len().trailing_zeros(); - - println!("{:?}", twiddles); - + + // The icfft has n layers. (0..log_2_size).for_each(|i| { + // In each layer i we split the current input in chunks of size 2^{n - i}. let chunk_size = 1 << log_2_size - i; let half_chunk_size = chunk_size >> 1; input.chunks_mut(chunk_size).for_each(|chunk| { + // We split each chunk in half, calling the first half hi_part and the second hal low_part. let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); - hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| { - let temp = *hi + *low; - *low = (*hi - *low) * twiddles[i as usize][j]; - *hi = temp; - }); + + // We apply the corresponding butterfly for every element j of the high and low part. + hi_part + .into_iter() + .zip(low_part) + .enumerate() + .for_each(|(j, (hi, low))| { + let temp = *hi + *low; + *low = (*hi - *low) * twiddles[i as usize][j]; + *hi = temp; + }); }); - }); -} + }); +} -// From [0, 2, 4, 6, 7, 5, 3, 1] to [0, 1, 2, 3, 4, 5, 6, 7] -pub fn order_cfft_result_naive(input: &mut [FieldElement]) -> Vec> { +/// This function permutes a slice of field elements to order the result of the cfft in the natural way. +/// We call the natural order to [P(x0, y0), P(x1, y1), P(x2, y2), ...], +/// where (x0, y0) is the first point of the corresponding coset. +/// The cfft doesn't return the evaluations in the natural order. +/// For example, if we apply the cfft to 8 coefficients of a polynomial of degree 7 we'll get the evaluations in this order: +/// [P(x0, y0), P(x2, y2), P(x4, y4), P(x6, y6), P(x7, y7), P(x5, y5), P(x3, y3), P(x1, y1)], +/// where the even indices are found first in ascending order and then the odd indices in descending order. +/// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7]. +pub fn order_cfft_result_naive( + input: &mut [FieldElement], +) -> Vec> { let mut result = Vec::new(); let length = input.len(); - for i in (0..length/2) { - result.push(input[i]); - result.push(input[length - i - 1]); + for i in 0..length / 2 { + result.push(input[i]); // We push the left index. + result.push(input[length - i - 1]); // We push the right index. } result } -// From [0, 1, 2, 3, 4, 5, 6, 7] to [0, 2, 4, 6, 7, 5, 3, 1] -pub fn order_icfft_input_naive(input: &mut [FieldElement]) -> Vec> { +/// This function permutes a slice of field elements to order the input of the icfft in a specific way. +/// For example, if we want to interpolate 8 points we should input them in the icfft in this order: +/// [(x0, y0), (x2, y2), (x4, y4), (x6, y6), (x7, y7), (x5, y5), (x3, y3), (x1, y1)], +/// where the even indices are found first in ascending order and then the odd indices in descending order. +/// This function permutes the slice [0, 1, 2, 3, 4, 5, 6, 7] into [0, 2, 4, 6, 7, 5, 3, 1]. +pub fn order_icfft_input_naive( + input: &mut [FieldElement], +) -> Vec> { let mut result = Vec::new(); - (0..input.len()).step_by(2).for_each( |i| { + + // We push the even indices. + (0..input.len()).step_by(2).for_each(|i| { result.push(input[i]); }); - (1..input.len()).step_by(2).rev().for_each( |i| { + + // We push the odd indices. + (1..input.len()).step_by(2).rev().for_each(|i| { result.push(input[i]); }); result } +// We are not using this fucntion. pub fn reverse_cfft_index(index: usize, length: usize) -> usize { - if index < (length >> 1) { // index < length / 2 + if index < (length >> 1) { + // index < length / 2 index << 1 // index * 2 } else { (((length - 1) - index) << 1) + 1 } } - pub fn cfft_4( input: &mut [FieldElement], twiddles: Vec>>, @@ -147,27 +188,16 @@ pub fn cfft_8( stage3.into_iter().map(|elem| elem * f).collect() } - #[cfg(test)] mod tests { use super::*; type FE = FieldElement; #[test] - fn ordering_4() { - let expected_slice = [ - FE::from(0), - FE::from(1), - FE::from(2), - FE::from(3), - ]; + fn ordering_cfft_result_works_for_4_points() { + let expected_slice = [FE::from(0), FE::from(1), FE::from(2), FE::from(3)]; - let mut slice = [ - FE::from(0), - FE::from(2), - FE::from(3), - FE::from(1), - ]; + let mut slice = [FE::from(0), FE::from(2), FE::from(3), FE::from(1)]; let res = order_cfft_result_naive(&mut slice); @@ -175,7 +205,7 @@ mod tests { } #[test] - fn ordering() { + fn ordering_cfft_result_works_for_16_points() { let expected_slice = [ FE::from(0), FE::from(1), @@ -232,7 +262,7 @@ mod tests { } #[test] - fn from_natural_to_icfft_input_order() { + fn from_natural_to_icfft_input_order_works() { let mut slice = [ FE::from(0), FE::from(1), @@ -275,6 +305,4 @@ mod tests { assert_eq!(res, expected_slice) } - - } diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index bf4da993c..507434fab 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -11,37 +11,48 @@ use super::{ }, }; -/// Given the 2^n coefficients of a two-variables polynomial in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} +/// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomianl on the points of the standard coset of size 2^n. /// Note that coeff has to be a vector with length a power of two 2^n. pub fn evaluate_cfft( mut coeff: Vec>, ) -> Vec> { - in_place_bit_reverse_permute::>(&mut coeff); + // We get the twiddles for the Evaluation. let domain_log_2_size: u32 = coeff.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Evaluation; let twiddles = get_twiddles(coset, config); + // For our algorithm to work, we must give as input the coefficients in bit reverse order. + in_place_bit_reverse_permute::>(&mut coeff); cfft(&mut coeff, twiddles); + + // The cfft returns the evaluations in a certain order, so we permute them to get the natural order. let result = order_cfft_result_naive(&mut coeff); result } -/// Interpolates the 2^n evaluations of a two-variables polynomial on the points of the standard coset of size 2^n. +/// Interpolates the 2^n evaluations of a two-variables polynomial of degree 2^n - 1 on the points of the standard coset of size 2^n. /// As a result we obtain the coefficients of the polynomial in the basis: {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// Note that eval has to be a vector of length a power of two 2^n. pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { - let mut eval_ordered = order_icfft_input_naive(&mut eval); + // We get the twiddles for the interpolation. let domain_log_2_size: u32 = eval.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Interpolation; let twiddles = get_twiddles(coset, config); + // For our algorithm to work, we must give as input the evaluations ordered in a certain way. + let mut eval_ordered = order_icfft_input_naive(&mut eval); icfft(&mut eval_ordered, twiddles); + + // The icfft returns the polynomial coefficients in bit reverse order. So we premute it to get the natural order. in_place_bit_reverse_permute::>(&mut eval_ordered); + + // The icfft returns all the coefficients multiplied by 2^n, the length of the evaluations. + // So we multiply every element that outputs the icfft byt the inverse of 2^n to get the actual coefficients. let factor = (FieldElement::::from(eval.len() as u64)) .inv() .unwrap(); @@ -74,30 +85,33 @@ pub fn interpolate_8( mod tests { use super::*; use crate::circle::cosets::Coset; - type FpE = FieldElement; + type FE = FieldElement; - fn evaluate_poly_4(coef: &[FpE; 4], x: FpE, y: FpE) -> FpE { + /// Naive evaluation of a polynomial of degree 3. + fn evaluate_poly_4(coef: &[FE; 4], x: FE, y: FE) -> FE { coef[0] + coef[1] * y + coef[2] * x + coef[3] * x * y } - fn evaluate_poly_8(coef: &[FpE; 8], x: FpE, y: FpE) -> FpE { + /// Naive evaluation of a polynomial of degree 7. + fn evaluate_poly_8(coef: &[FE; 8], x: FE, y: FE) -> FE { coef[0] + coef[1] * y + coef[2] * x + coef[3] * x * y - + coef[4] * (x.square().double() - FpE::one()) - + coef[5] * (x.square().double() - FpE::one()) * y + + coef[4] * (x.square().double() - FE::one()) + + coef[5] * (x.square().double() - FE::one()) * y + coef[6] * ((x.square() * x).double() - x) + coef[7] * ((x.square() * x).double() - x) * y } - fn evaluate_poly_16(coef: &[FpE; 16], x: FpE, y: FpE) -> FpE { + /// Naive evaluation of a polynomial of degree 15. + fn evaluate_poly_16(coef: &[FE; 16], x: FE, y: FE) -> FE { let mut a = x; let mut v = Vec::new(); - v.push(FpE::one()); + v.push(FE::one()); v.push(x); for _ in 2..4 { - a = a.square().double() - FpE::one(); + a = a.square().double() - FE::one(); v.push(a); } @@ -120,92 +134,108 @@ mod tests { } #[test] + /// cfft evaluation equals naive evaluation. fn cfft_evaluation_4_points() { - // We create the coset points and evaluate them without the fft. + // We define the coefficients of a polynomial of degree 3. + let input = [FE::from(1), FE::from(2), FE::from(3), FE::from(4)]; + + // We create the coset points and evaluate the polynomial with the naive function. let coset = Coset::new_standard(2); let points = Coset::get_coset_points(&coset); - let input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)]; - let mut expected_result: Vec = Vec::new(); + let mut expected_result: Vec = Vec::new(); for point in points { let point_eval = evaluate_poly_4(&input, point.x, point.y); expected_result.push(point_eval); } + // We evaluate the polynomial using now the cfft. let result = evaluate_cfft(input.to_vec()); - let slice_result: &[FpE] = &result; + let slice_result: &[FE] = &result; + assert_eq!(slice_result, expected_result); } #[test] + /// cfft evaluation equals naive evaluation. fn cfft_evaluation_8_points() { + // We define the coefficients of a polynomial of degree 7. + let input = [ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + ]; + // We create the coset points and evaluate them without the fft. let coset = Coset::new_standard(3); let points = Coset::get_coset_points(&coset); - let input = [ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), - ]; - let mut expected_result: Vec = Vec::new(); + let mut expected_result: Vec = Vec::new(); for point in points { let point_eval = evaluate_poly_8(&input, point.x, point.y); expected_result.push(point_eval); } + // We evaluate the polynomial using now the cfft. let result = evaluate_cfft(input.to_vec()); - let slice_result: &[FpE] = &result; + let slice_result: &[FE] = &result; + assert_eq!(slice_result, expected_result); } #[test] + /// cfft evaluation equals naive evaluation. fn cfft_evaluation_16_points() { - let coset = Coset::new_standard(4); - let points = Coset::get_coset_points(&coset); + // We define the coefficients of a polynomial of degree 15. let input = [ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), - FpE::from(9), - FpE::from(10), - FpE::from(11), - FpE::from(12), - FpE::from(13), - FpE::from(14), - FpE::from(15), - FpE::from(16), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + FE::from(16), ]; - let mut expected_result: Vec = Vec::new(); + + // We create the coset points and evaluate them without the fft. + let coset = Coset::new_standard(4); + let points = Coset::get_coset_points(&coset); + let mut expected_result: Vec = Vec::new(); for point in points { let point_eval = evaluate_poly_16(&input, point.x, point.y); expected_result.push(point_eval); } + // We evaluate the polynomial using now the cfft. let result = evaluate_cfft(input.to_vec()); - let slice_result: &[FpE] = &result; + let slice_result: &[FE] = &result; + assert_eq!(slice_result, expected_result); } #[test] fn interpolation() { let coeff = vec![ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), ]; let evals = evaluate_cfft(coeff.clone()); @@ -223,32 +253,92 @@ mod tests { // FieldElement { value: 2092305986 }] -> 7 let new_evals = vec![ - FpE::from(885347334), - FpE::from(714723476), - FpE::from(1262332919), - FpE::from(1432563561), - FpE::from(2092305986), - FpE::from(1109642644), - FpE::from(55636419), - FpE::from(1037382257), + FE::from(885347334), + FE::from(714723476), + FE::from(1262332919), + FE::from(1432563561), + FE::from(2092305986), + FE::from(1109642644), + FE::from(55636419), + FE::from(1037382257), ]; let new_coeff = interpolate_8(new_evals); + } + + #[test] + fn evaluate_and_interpolate_8_points_is_identity() { + // We define the 8 coefficients of a polynomial of degree 7. + let coeff = vec![ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); - println!("RES: {:?}", new_coeff); + assert_eq!(coeff, new_coeff); + } + + #[test] + fn evaluate_and_interpolate_8_other_points() { + let coeff = vec![ + FE::from(2147483650), + FE::from(147483647), + FE::from(2147483700), + FE::from(2147483647), + FE::from(3147483647), + FE::from(4147483647), + FE::from(2147483640), + FE::from(5147483647), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); + + assert_eq!(coeff, new_coeff); } #[test] - fn evaluate_and_interpolate() { + fn evaluate_and_interpolate_32_points() { + // We define 32 coefficients of a polynomial of degree 31. let coeff = vec![ - FpE::from(1), - FpE::from(2), - FpE::from(3), - FpE::from(4), - FpE::from(5), - FpE::from(6), - FpE::from(7), - FpE::from(8), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + FE::from(16), + FE::from(17), + FE::from(18), + FE::from(19), + FE::from(20), + FE::from(21), + FE::from(22), + FE::from(23), + FE::from(24), + FE::from(25), + FE::from(26), + FE::from(27), + FE::from(28), + FE::from(29), + FE::from(30), + FE::from(31), + FE::from(32), ]; let evals = evaluate_cfft(coeff.clone()); let new_coeff = interpolate_cfft(evals); diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index abe48581b..a38ef6ac6 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -16,18 +16,28 @@ pub fn get_twiddles( domain: Coset, config: TwiddlesConfig, ) -> Vec>> { + // We first take the half coset. let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); + // The first set of twiddles are all the y coordinates of the half coset. let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; if domain.log_2_size >= 2 { - twiddles.push(half_domain_points.iter().take(half_domain_points.len() / 2 ).map(|p| p.x).collect()); + // The second set of twiddles are the x coordinates of the first half of the half coset. + twiddles.push( + half_domain_points + .iter() + .take(half_domain_points.len() / 2) + .map(|p| p.x) + .collect(), + ); for _ in 0..(domain.log_2_size - 2) { + // The rest of the sets of twiddles are the "square" of the x coordinates of the first half of the previous set. let prev = twiddles.last().unwrap(); let cur = prev .iter() - .take(prev.len() / 2 ) + .take(prev.len() / 2) .map(|x| x.square().double() - FieldElement::::one()) .collect(); twiddles.push(cur); @@ -40,7 +50,7 @@ pub fn get_twiddles( FieldElement::::inplace_batch_inverse(x).unwrap(); }); } else { - // For the evaluation, we need the vector of twiddles but in the inverse order. + // For the evaluation, we need reverse the order of the vector of twiddles. twiddles.reverse(); } twiddles @@ -76,10 +86,20 @@ mod tests { use super::*; #[test] - fn twiddles_vectors_lenght() { + fn evaluation_twiddles_vectors_length_is_correct() { let domain = Coset::new_standard(3); let config = TwiddlesConfig::Evaluation; let twiddles = get_twiddles(domain, config); + for i in 0..twiddles.len() - 1 { + assert_eq!(2 * twiddles[i].len(), twiddles[i + 1].len()) + } + } + + #[test] + fn interpolation_twiddles_vectors_length_is_correct() { + let domain = Coset::new_standard(3); + let config = TwiddlesConfig::Interpolation; + let twiddles = get_twiddles(domain, config); for i in 0..twiddles.len() - 1 { assert_eq!(twiddles[i].len(), 2 * twiddles[i + 1].len()) } From da5fae722e6eadfdfd20d3df3f928c91365f223b Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 17 Oct 2024 11:21:38 -0300 Subject: [PATCH 47/93] clippy --- math/src/circle/cfft.rs | 79 +++-------------------------------- math/src/circle/polynomial.rs | 70 ++----------------------------- math/src/circle/twiddles.rs | 28 +------------ 3 files changed, 9 insertions(+), 168 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 94eab38b7..d20c7346a 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -14,7 +14,7 @@ pub fn cfft( // The cfft has n layers. (0..log_2_size).for_each(|i| { // In each layer i we split the current input in chunks of size 2^{i+1}. - let chunk_size = 1 << i + 1; + let chunk_size = 1 << (i + 1); let half_chunk_size = 1 << i; input.chunks_mut(chunk_size).for_each(|chunk| { // We split each chunk in half, calling the first half hi_part and the second hal low_part. @@ -22,13 +22,13 @@ pub fn cfft( // We apply the corresponding butterfly for every element j of the high and low part. hi_part - .into_iter() + .iter_mut() .zip(low_part) .enumerate() .for_each(|(j, (hi, low))| { let temp = *low * twiddles[i as usize][j]; *low = *hi - temp; - *hi = *hi + temp; + *hi += temp }); }); }); @@ -47,7 +47,7 @@ pub fn icfft( // The icfft has n layers. (0..log_2_size).for_each(|i| { // In each layer i we split the current input in chunks of size 2^{n - i}. - let chunk_size = 1 << log_2_size - i; + let chunk_size = 1 << (log_2_size - i); let half_chunk_size = chunk_size >> 1; input.chunks_mut(chunk_size).for_each(|chunk| { // We split each chunk in half, calling the first half hi_part and the second hal low_part. @@ -55,7 +55,7 @@ pub fn icfft( // We apply the corresponding butterfly for every element j of the high and low part. hi_part - .into_iter() + .iter_mut() .zip(low_part) .enumerate() .for_each(|(j, (hi, low))| { @@ -119,75 +119,6 @@ pub fn reverse_cfft_index(index: usize, length: usize) -> usize { } } -pub fn cfft_4( - input: &mut [FieldElement], - twiddles: Vec>>, -) -> Vec> { - let mut stage1: Vec> = Vec::with_capacity(4); - - stage1.push(input[0] + input[1]); - stage1.push((input[0] - input[1]) * twiddles[0][0]); - - stage1.push(input[2] + input[3]); - stage1.push((input[2] - input[3]) * twiddles[0][1]); - - let mut stage2: Vec> = Vec::with_capacity(4); - - stage2.push(stage1[0] + stage1[2]); - stage2.push(stage1[1] + stage1[3]); - - stage2.push((stage1[0] - stage1[2]) * twiddles[1][0]); - stage2.push((stage1[1] - stage1[3]) * twiddles[1][0]); - - let f = FieldElement::::from(4).inv().unwrap(); - stage2.into_iter().map(|elem| elem * f).collect() -} - -pub fn cfft_8( - input: &mut [FieldElement], - twiddles: Vec>>, -) -> Vec> { - let mut stage1: Vec> = Vec::with_capacity(8); - - stage1.push(input[0] + input[4]); - stage1.push(input[1] + input[5]); - stage1.push(input[2] + input[6]); - stage1.push(input[3] + input[7]); - stage1.push((input[0] - input[4]) * twiddles[0][0]); - stage1.push((input[1] - input[5]) * twiddles[0][1]); - stage1.push((input[2] - input[6]) * twiddles[0][2]); - stage1.push((input[3] - input[7]) * twiddles[0][3]); - - let mut stage2: Vec> = Vec::with_capacity(8); - - stage2.push(stage1[0] + stage1[2]); - stage2.push(stage1[1] + stage1[3]); - stage2.push((stage1[0] - stage1[2]) * twiddles[1][0]); - stage2.push((stage1[1] - stage1[3]) * twiddles[1][1]); - - stage2.push(stage1[4] + stage1[6]); - stage2.push(stage1[5] + stage1[7]); - stage2.push((stage1[4] - stage1[6]) * twiddles[1][0]); - stage2.push((stage1[5] - stage1[7]) * twiddles[1][1]); - - let mut stage3: Vec> = Vec::with_capacity(8); - - stage3.push(stage2[0] + stage2[1]); - stage3.push((stage2[0] - stage2[1]) * twiddles[2][0]); - - stage3.push(stage2[2] + stage2[3]); - stage3.push((stage2[2] - stage2[3]) * twiddles[2][0]); - - stage3.push(stage2[4] + stage2[5]); - stage3.push((stage2[4] - stage2[5]) * twiddles[2][0]); - - stage3.push(stage2[6] + stage2[7]); - stage3.push((stage2[6] - stage2[7]) * twiddles[2][0]); - - let f = FieldElement::::from(8).inv().unwrap(); - stage3.into_iter().map(|elem| elem * f).collect() -} - #[cfg(test)] mod tests { use super::*; diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 507434fab..e34e35b8a 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -4,10 +4,10 @@ use crate::{ }; use super::{ - cfft::{cfft, cfft_4, cfft_8, icfft, order_cfft_result_naive, order_icfft_input_naive}, + cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, twiddles::{ - get_twiddles, get_twiddles_itnerpolation_4, get_twiddles_itnerpolation_8, TwiddlesConfig, + get_twiddles, TwiddlesConfig, }, }; @@ -28,8 +28,7 @@ pub fn evaluate_cfft( cfft(&mut coeff, twiddles); // The cfft returns the evaluations in a certain order, so we permute them to get the natural order. - let result = order_cfft_result_naive(&mut coeff); - result + order_cfft_result_naive(&mut coeff) } /// Interpolates the 2^n evaluations of a two-variables polynomial of degree 2^n - 1 on the points of the standard coset of size 2^n. @@ -59,28 +58,6 @@ pub fn interpolate_cfft( eval_ordered.iter().map(|coef| coef * factor).collect() } -pub fn interpolate_4( - mut eval: Vec>, -) -> Vec> { - let domain_log_2_size: u32 = eval.len().trailing_zeros(); - let coset = Coset::new_standard(domain_log_2_size); - let twiddles = get_twiddles_itnerpolation_4(coset); - - let res = cfft_4(&mut eval, twiddles); - res -} - -pub fn interpolate_8( - mut eval: Vec>, -) -> Vec> { - let domain_log_2_size: u32 = eval.len().trailing_zeros(); - let coset = Coset::new_standard(domain_log_2_size); - let twiddles = get_twiddles_itnerpolation_8(coset); - - let res = cfft_8(&mut eval, twiddles); - res -} - #[cfg(test)] mod tests { use super::*; @@ -225,47 +202,6 @@ mod tests { assert_eq!(slice_result, expected_result); } - #[test] - fn interpolation() { - let coeff = vec![ - FE::from(1), - FE::from(2), - FE::from(3), - FE::from(4), - FE::from(5), - FE::from(6), - FE::from(7), - FE::from(8), - ]; - - let evals = evaluate_cfft(coeff.clone()); - - // println!("EVALS: {:?}", evals); - - // EVALS: [ - // FieldElement { value: 885347334 }, -> 0 - // FieldElement { value: 1037382257 }, -> 1 - // FieldElement { value: 714723476 }, -> 2 - // FieldElement { value: 55636419 }, -> 3 - // FieldElement { value: 1262332919 }, -> 4 - // FieldElement { value: 1109642644 }, -> 5 - // FieldElement { value: 1432563561 }, -> 6 - // FieldElement { value: 2092305986 }] -> 7 - - let new_evals = vec![ - FE::from(885347334), - FE::from(714723476), - FE::from(1262332919), - FE::from(1432563561), - FE::from(2092305986), - FE::from(1109642644), - FE::from(55636419), - FE::from(1037382257), - ]; - - let new_coeff = interpolate_8(new_evals); - } - #[test] fn evaluate_and_interpolate_8_points_is_identity() { // We define the 8 coefficients of a polynomial of degree 7. diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index a38ef6ac6..1d0351388 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -1,7 +1,6 @@ extern crate alloc; use crate::{ - circle::{cosets::Coset, point::CirclePoint}, - fft::cpu::bit_reversing::in_place_bit_reverse_permute, + circle::cosets::Coset, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; use alloc::vec::Vec; @@ -56,31 +55,6 @@ pub fn get_twiddles( twiddles } -pub fn get_twiddles_itnerpolation_4(domain: Coset) -> Vec>> { - let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); - let mut twiddles: Vec>> = - vec![half_domain_points.iter().map(|p| p.y).collect()]; - twiddles.push(half_domain_points.iter().take(1).map(|p| p.x).collect()); - twiddles.iter_mut().for_each(|x| { - FieldElement::::inplace_batch_inverse(x).unwrap(); - }); - twiddles -} - -pub fn get_twiddles_itnerpolation_8(domain: Coset) -> Vec>> { - let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); - let mut twiddles: Vec>> = - vec![half_domain_points.iter().map(|p| p.y).collect()]; - twiddles.push(half_domain_points.iter().take(2).map(|p| p.x).collect()); - twiddles.push(vec![ - half_domain_points[0].x.square().double() - FieldElement::::one(), - ]); - twiddles.iter_mut().for_each(|x| { - FieldElement::::inplace_batch_inverse(x).unwrap(); - }); - twiddles -} - #[cfg(test)] mod tests { use super::*; From 28ab71fb55ae9747d718db7c72119f673c47e836 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 17 Oct 2024 11:25:49 -0300 Subject: [PATCH 48/93] fmt --- math/src/circle/mod.rs | 2 +- math/src/circle/polynomial.rs | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index b76831d65..ac576194f 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -2,5 +2,5 @@ pub mod cfft; pub mod cosets; pub mod errors; pub mod point; +pub mod polynomial; pub mod twiddles; -pub mod polynomial; \ No newline at end of file diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index e34e35b8a..93baa97ca 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -6,9 +6,7 @@ use crate::{ use super::{ cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, - twiddles::{ - get_twiddles, TwiddlesConfig, - }, + twiddles::{get_twiddles, TwiddlesConfig}, }; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} From e0c666df0a59f80d50ff5e3f0d79ba510151a939 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 17 Oct 2024 11:47:16 -0300 Subject: [PATCH 49/93] remove unused functions --- math/src/circle/cfft.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index d20c7346a..26509bbe6 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -109,16 +109,6 @@ pub fn order_icfft_input_naive( result } -// We are not using this fucntion. -pub fn reverse_cfft_index(index: usize, length: usize) -> usize { - if index < (length >> 1) { - // index < length / 2 - index << 1 // index * 2 - } else { - (((length - 1) - index) << 1) + 1 - } -} - #[cfg(test)] mod tests { use super::*; From 3073470fbef9a61908e58edc09c2f59acaa72bb1 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 17 Oct 2024 12:23:53 -0300 Subject: [PATCH 50/93] add comment --- math/src/circle/cfft.rs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 26509bbe6..424a8a982 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -75,6 +75,7 @@ pub fn icfft( /// [P(x0, y0), P(x2, y2), P(x4, y4), P(x6, y6), P(x7, y7), P(x5, y5), P(x3, y3), P(x1, y1)], /// where the even indices are found first in ascending order and then the odd indices in descending order. /// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7]. +/// NOTE: This can be optimized by performing in-place value swapping (WIP). pub fn order_cfft_result_naive( input: &mut [FieldElement], ) -> Vec> { @@ -92,6 +93,7 @@ pub fn order_cfft_result_naive( /// [(x0, y0), (x2, y2), (x4, y4), (x6, y6), (x7, y7), (x5, y5), (x3, y3), (x1, y1)], /// where the even indices are found first in ascending order and then the odd indices in descending order. /// This function permutes the slice [0, 1, 2, 3, 4, 5, 6, 7] into [0, 2, 4, 6, 7, 5, 3, 1]. +/// NOTE: This can be optimized by performing in-place value swapping (WIP). pub fn order_icfft_input_naive( input: &mut [FieldElement], ) -> Vec> { @@ -170,18 +172,6 @@ mod tests { assert_eq!(res, expected_slice) } - #[test] - fn reverse_cfft_index_works() { - let mut reversed: Vec = Vec::with_capacity(16); - for i in 0..reversed.capacity() { - reversed.push(reverse_cfft_index(i, reversed.capacity())); - } - assert_eq!( - reversed[..], - [0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1] - ); - } - #[test] fn from_natural_to_icfft_input_order_works() { let mut slice = [ From 37aae9e4bc339e55eae288839189fb15fd3dbbce Mon Sep 17 00:00:00 2001 From: Nicole Graus Date: Wed, 23 Oct 2024 11:24:03 -0300 Subject: [PATCH 51/93] Update math/src/circle/polynomial.rs Co-authored-by: Ivan Litteri <67517699+ilitteri@users.noreply.github.com> --- math/src/circle/polynomial.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 93baa97ca..953b75834 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -10,7 +10,7 @@ use super::{ }; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} -/// returns the evaluation of the polynomianl on the points of the standard coset of size 2^n. +/// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. /// Note that coeff has to be a vector with length a power of two 2^n. pub fn evaluate_cfft( mut coeff: Vec>, From 4436e28dd7c3b06ef4fda1aa8135eac03a4d513a Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 14:15:34 -0300 Subject: [PATCH 52/93] change generator and order CirclePoint functions as constants --- math/src/circle/cosets.rs | 4 +- math/src/circle/point.rs | 70 ++++++++----------- .../src/field/fields/mersenne31/extensions.rs | 9 ++- 3 files changed, 38 insertions(+), 45 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 0da5f204b..171350dca 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -25,7 +25,7 @@ impl Coset { /// Returns g_n, the generator of the subgroup of order n = 2^log_2_size. pub fn get_generator(&self) -> CirclePoint { - CirclePoint::generator().repeated_double(31 - self.log_2_size) + CirclePoint::GENERATOR.repeated_double(31 - self.log_2_size) } /// Given a standard coset g_2n + , returns the subcoset with half size g_2n + @@ -82,7 +82,7 @@ mod tests { #[test] fn coset_generator_has_right_order() { - let coset = Coset::new(2, CirclePoint::generator().scalar_mul(3)); + let coset = Coset::new(2, CirclePoint::GENERATOR.scalar_mul(3)); let generator_n = coset.get_generator(); assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); } diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 92b6200f5..5ac47fbac 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -15,17 +15,21 @@ pub struct CirclePoint { pub trait HasCircleParams { type FE; - fn circle_generator() -> (FieldElement, FieldElement); + /// Coordinate x of the generator of the circle group. + const CIRCLE_GENERATOR_X: FieldElement; + + /// Coordinate y of the generator of the circle group. + const CIRCLE_GENERATOR_Y: FieldElement; + const ORDER: u128; } impl HasCircleParams for Mersenne31Field { type FE = FieldElement; - // This could be a constant instead of a function - fn circle_generator() -> (Self::FE, Self::FE) { - (Self::FE::from(&2), Self::FE::from(&1268011823)) - } + const CIRCLE_GENERATOR_X: Self::FE = Self::FE::const_from_raw(2); + + const CIRCLE_GENERATOR_Y: Self::FE = Self::FE::const_from_raw(1268011823); /// ORDER = 2^31 const ORDER: u128 = 2147483648; @@ -34,26 +38,11 @@ impl HasCircleParams for Mersenne31Field { impl HasCircleParams for Degree4ExtensionField { type FE = FieldElement; - // This could be a constant instead of a function - fn circle_generator() -> ( - FieldElement, - FieldElement, - ) { - ( - Degree4ExtensionField::from_coeffcients( - FieldElement::::one(), - FieldElement::::zero(), - FieldElement::::from(&478637715), - FieldElement::::from(&513582971), - ), - Degree4ExtensionField::from_coeffcients( - FieldElement::::from(992285211), - FieldElement::::from(649143431), - FieldElement::::from(&740191619), - FieldElement::::from(&1186584352), - ), - ) - } + const CIRCLE_GENERATOR_X: Self::FE = + Degree4ExtensionField::const_from_coefficients(1, 0, 478637715, 513582971); + + const CIRCLE_GENERATOR_Y: Self::FE = + Degree4ExtensionField::const_from_coefficients(992285211, 649143431, 740191619, 1186584352); /// ORDER = (2^31 - 1)^4 - 1 const ORDER: u128 = 21267647892944572736998860269687930880; @@ -62,7 +51,7 @@ impl HasCircleParams for Degree4ExtensionField { impl> CirclePoint { pub fn new(x: FieldElement, y: FieldElement) -> Result { if x.square() + y.square() == FieldElement::one() { - Ok(CirclePoint { x, y }) + Ok(Self { x, y }) } else { Err(CircleError::InvalidValue) } @@ -116,7 +105,7 @@ impl> CirclePoint { } /// Computes the inverse of the point. - /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation. + /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation /// because the norm of every point in the circle is one. pub fn conjugate(self) -> Self { Self { @@ -136,19 +125,18 @@ impl> CirclePoint { a.x == b.x && a.y == b.y } - pub fn generator() -> Self { - CirclePoint::new(F::circle_generator().0, F::circle_generator().1).unwrap() - } + pub const GENERATOR: Self = Self { + x: F::CIRCLE_GENERATOR_X, + y: F::CIRCLE_GENERATOR_Y, + }; /// Returns the generator of the subgroup of order n = 2^log_2_size. /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { - Self::generator().repeated_double(31 - log_2_size) + Self::GENERATOR.repeated_double(31 - log_2_size) } - pub fn group_order() -> u128 { - F::ORDER - } + pub const ORDER: u128 = F::ORDER; } impl> PartialEq for CirclePoint { @@ -216,39 +204,39 @@ mod tests { #[test] fn generator_plus_zero_is_generator() { - let g = G::generator(); + let g = G::GENERATOR; let zero = G::zero(); assert_eq!(g.clone() + zero, g) } #[test] fn double_equals_mul_two() { - let g = G::generator(); + let g = G::GENERATOR; assert_eq!(g.clone().double(), G::scalar_mul(g, 2)) } #[test] fn mul_eight_equals_double_three_times() { - let g = G::generator(); + let g = G::GENERATOR; assert_eq!(g.clone().repeated_double(3), G::scalar_mul(g, 8)) } #[test] fn generator_g1_has_order_two_pow_31() { - let g = G::generator(); + let g = G::GENERATOR; let n = 31; assert_eq!(g.repeated_double(n), G::zero()) } #[test] fn generator_g4_has_the_order_of_the_group() { - let g = G4::generator(); - assert_eq!(g.scalar_mul(G4::group_order()), G4::zero()) + let g = G4::GENERATOR; + assert_eq!(g.scalar_mul(G4::ORDER), G4::zero()) } #[test] fn conjugation_is_inverse_operation() { - let g = G::generator(); + let g = G::GENERATOR; assert_eq!(g.clone() + g.conjugate(), G::zero()) } diff --git a/math/src/field/fields/mersenne31/extensions.rs b/math/src/field/fields/mersenne31/extensions.rs index 2ec853ec0..d6368d8b1 100644 --- a/math/src/field/fields/mersenne31/extensions.rs +++ b/math/src/field/fields/mersenne31/extensions.rs @@ -1,3 +1,5 @@ +use core::num::FpCategory; + use super::field::Mersenne31Field; use crate::field::{ element::FieldElement, @@ -138,8 +140,11 @@ impl IsSubFieldOf for Mersenne31Field { pub struct Degree4ExtensionField; impl Degree4ExtensionField { - pub fn from_coeffcients(a: FpE, b: FpE, c: FpE, d: FpE) -> Fp4E { - Fp4E::new([Fp2E::new([a, b]), Fp2E::new([c, d])]) + pub const fn const_from_coefficients(a: u32, b: u32, c: u32, d: u32) -> Fp4E { + Fp4E::const_from_raw([ + Fp2E::const_from_raw([FpE::const_from_raw(a), FpE::const_from_raw(b)]), + Fp2E::const_from_raw([FpE::const_from_raw(c), FpE::const_from_raw(d)]), + ]) } } From 835a56508185549884d979ed211102aedd4b4378 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 14:45:04 -0300 Subject: [PATCH 53/93] impl eq as PartialEq --- math/src/circle/point.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 5ac47fbac..0f1db689e 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -121,10 +121,6 @@ impl> CirclePoint { } } - pub fn eq(a: Self, b: Self) -> bool { - a.x == b.x && a.y == b.y - } - pub const GENERATOR: Self = Self { x: F::CIRCLE_GENERATOR_X, y: F::CIRCLE_GENERATOR_Y, @@ -141,7 +137,7 @@ impl> CirclePoint { impl> PartialEq for CirclePoint { fn eq(&self, other: &Self) -> bool { - CirclePoint::eq(self.clone(), other.clone()) + self.x == other.x && self.y == other.y } } From 2ea3171eaa87df84529555475077e0a35c24775e Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 16:00:11 -0300 Subject: [PATCH 54/93] implement scalar_mul as Mul --- math/src/circle/cosets.rs | 4 +- math/src/circle/point.rs | 81 ++++++++++++++++++++------------------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 171350dca..9b3cbf346 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -53,7 +53,7 @@ impl Coset { let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); let size: u8 = 1 << coset.log_2_size; core::iter::successors(Some(coset.shift.clone()), move |prev| { - Some(prev.clone() + generator_n.clone()) + Some(prev + &generator_n) }) .take(size.into()) .collect() @@ -82,7 +82,7 @@ mod tests { #[test] fn coset_generator_has_right_order() { - let coset = Coset::new(2, CirclePoint::GENERATOR.scalar_mul(3)); + let coset = Coset::new(2, CirclePoint::GENERATOR * 3); let generator_n = coset.get_generator(); assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); } diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 0f1db689e..ac8824d8d 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -4,7 +4,7 @@ use crate::field::{ element::FieldElement, fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, }; -use core::ops::Add; +use core::ops::{Add, Mul}; #[derive(Debug, Clone)] pub struct CirclePoint { @@ -48,30 +48,31 @@ impl HasCircleParams for Degree4ExtensionField { const ORDER: u128 = 21267647892944572736998860269687930880; } -impl> CirclePoint { - pub fn new(x: FieldElement, y: FieldElement) -> Result { - if x.square() + y.square() == FieldElement::one() { - Ok(Self { x, y }) - } else { - Err(CircleError::InvalidValue) - } +/// Equality between two cricle points. +impl> PartialEq for CirclePoint { + fn eq(&self, other: &Self) -> bool { + self.x == other.x && self.y == other.y } +} - /// Neutral element of the Circle group (with additive notation). - pub fn zero() -> Self { - Self::new(FieldElement::one(), FieldElement::zero()).unwrap() - } +/// Addition (i.e. group operation with additive notation) between two points: +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) +impl> Add for &CirclePoint { + type Output = CirclePoint; - /// Computes (a0, a1) + (b0, b1) = (a0 * b0 - a1 * b1, a0 * b1 + a1 * b0) - #[allow(clippy::should_implement_trait)] - pub fn add(a: Self, b: Self) -> Self { - let x = &a.x * &b.x - &a.y * &b.y; - let y = a.x * b.y + a.y * b.x; + fn add(self, other: Self) -> Self::Output { + let x = &self.x * &other.x - &self.y * &other.y; + let y = &self.x * &other.y + &self.y * &other.x; CirclePoint { x, y } } +} - /// Computes n * (x, y) = (x ,y) + ... + (x, y) n-times. - pub fn scalar_mul(self, mut scalar: u128) -> Self { +/// Multiplication between a point and a scalar (i.e. group operation repeatedly): +/// (x, y) * n = (x ,y) + ... + (x, y) n-times. +impl> Mul for CirclePoint { + type Output = CirclePoint; + + fn mul(self, mut scalar: u128) -> Self { let mut res = Self::zero(); let mut cur = self; loop { @@ -79,12 +80,27 @@ impl> CirclePoint { return res; } if scalar & 1 == 1 { - res = res + cur.clone(); + res = &res + &cur; } cur = cur.double(); scalar >>= 1; } } +} + +impl> CirclePoint { + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(Self { x, y }) + } else { + Err(CircleError::InvalidValue) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } /// Computes 2(x, y) = (2x^2 - 1, 2xy). pub fn double(self) -> Self { @@ -135,19 +151,6 @@ impl> CirclePoint { pub const ORDER: u128 = F::ORDER; } -impl> PartialEq for CirclePoint { - fn eq(&self, other: &Self) -> bool { - self.x == other.x && self.y == other.y - } -} - -impl> Add for CirclePoint { - type Output = CirclePoint; - fn add(self, other: Self) -> Self { - CirclePoint::add(self, other) - } -} - #[cfg(test)] mod tests { use super::*; @@ -195,26 +198,26 @@ mod tests { fn zero_plus_zero_is_zero() { let a = G::zero(); let b = G::zero(); - assert_eq!(a + b, G::zero()) + assert_eq!(&a + &b, G::zero()) } #[test] fn generator_plus_zero_is_generator() { let g = G::GENERATOR; let zero = G::zero(); - assert_eq!(g.clone() + zero, g) + assert_eq!(&g + &zero, g) } #[test] fn double_equals_mul_two() { let g = G::GENERATOR; - assert_eq!(g.clone().double(), G::scalar_mul(g, 2)) + assert_eq!(g.clone().double(), g * 2) } #[test] fn mul_eight_equals_double_three_times() { let g = G::GENERATOR; - assert_eq!(g.clone().repeated_double(3), G::scalar_mul(g, 8)) + assert_eq!(g.clone().repeated_double(3), g * 8) } #[test] @@ -227,13 +230,13 @@ mod tests { #[test] fn generator_g4_has_the_order_of_the_group() { let g = G4::GENERATOR; - assert_eq!(g.scalar_mul(G4::ORDER), G4::zero()) + assert_eq!(g * G4::ORDER, G4::zero()) } #[test] fn conjugation_is_inverse_operation() { let g = G::GENERATOR; - assert_eq!(g.clone() + g.conjugate(), G::zero()) + assert_eq!(&g.clone() + &g.conjugate(), G::zero()) } #[test] From c9acb7406e699ec8ad2ca467494b22a638b290dd Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 16:28:29 -0300 Subject: [PATCH 55/93] Change error name to a more descriptive one and move it to point.rs --- math/src/circle/errors.rs | 4 ---- math/src/circle/mod.rs | 1 - math/src/circle/point.rs | 16 +++++++++++++--- 3 files changed, 13 insertions(+), 8 deletions(-) delete mode 100644 math/src/circle/errors.rs diff --git a/math/src/circle/errors.rs b/math/src/circle/errors.rs deleted file mode 100644 index 07b63ec70..000000000 --- a/math/src/circle/errors.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[derive(Debug)] -pub enum CircleError { - InvalidValue, -} diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index ac576194f..f5a65721f 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -1,6 +1,5 @@ pub mod cfft; pub mod cosets; -pub mod errors; pub mod point; pub mod polynomial; pub mod twiddles; diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index ac8824d8d..13d97de3b 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -1,4 +1,3 @@ -use super::errors::CircleError; use crate::field::traits::IsField; use crate::field::{ element::FieldElement, @@ -6,12 +5,23 @@ use crate::field::{ }; use core::ops::{Add, Mul}; +/// Given a Field F, we implement here the Group which consists of all the points (x, y) such as +/// x in F, y in F and x^2 + y^2 = 1, i.e. the Circle. The operation of the group will have +/// additive notation and is as follows: +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) + #[derive(Debug, Clone)] pub struct CirclePoint { pub x: FieldElement, pub y: FieldElement, } +#[derive(Debug)] +pub enum CircleError { + PointDoesntSatisfyCircleEquation, +} + +/// Parameters of the base field that we'll need to define its Circle. pub trait HasCircleParams { type FE; @@ -55,7 +65,7 @@ impl> PartialEq for CirclePoint { } } -/// Addition (i.e. group operation with additive notation) between two points: +/// Addition (i.e. group operation) between two points: /// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) impl> Add for &CirclePoint { type Output = CirclePoint; @@ -93,7 +103,7 @@ impl> CirclePoint { if x.square() + y.square() == FieldElement::one() { Ok(Self { x, y }) } else { - Err(CircleError::InvalidValue) + Err(CircleError::PointDoesntSatisfyCircleEquation) } } From 93fd51566624d13d490367a335f445f8ed06584b Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 16:57:07 -0300 Subject: [PATCH 56/93] fix lint --- math/src/field/fields/mersenne31/extensions.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/math/src/field/fields/mersenne31/extensions.rs b/math/src/field/fields/mersenne31/extensions.rs index d6368d8b1..69c64f096 100644 --- a/math/src/field/fields/mersenne31/extensions.rs +++ b/math/src/field/fields/mersenne31/extensions.rs @@ -1,5 +1,3 @@ -use core::num::FpCategory; - use super::field::Mersenne31Field; use crate::field::{ element::FieldElement, From 7ae499081f5a9ae8b7db2311c13805c2ed938099 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 17:02:35 -0300 Subject: [PATCH 57/93] fix lint --- math/src/circle/cfft.rs | 1 + math/src/circle/polynomial.rs | 1 + math/src/circle/twiddles.rs | 1 + 3 files changed, 3 insertions(+) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 424a8a982..3fb05ff1d 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -1,5 +1,6 @@ extern crate alloc; use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use alloc::vec::Vec; #[cfg(feature = "alloc")] /// fft in place algorithm used to evaluate a polynomial of degree 2^n - 1 in 2^n points. diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 953b75834..a83e4e1be 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -8,6 +8,7 @@ use super::{ cosets::Coset, twiddles::{get_twiddles, TwiddlesConfig}, }; +use alloc::vec::Vec; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 1d0351388..33693370e 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -3,6 +3,7 @@ use crate::{ circle::cosets::Coset, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; +use alloc::vec; use alloc::vec::Vec; #[derive(PartialEq)] From 880131709a08072cba0d6805fb5cca2b0657c078 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 23 Oct 2024 17:32:35 -0300 Subject: [PATCH 58/93] fix some comments --- math/src/circle/cfft.rs | 4 ++-- math/src/circle/cosets.rs | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 3fb05ff1d..3b327a954 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -76,7 +76,7 @@ pub fn icfft( /// [P(x0, y0), P(x2, y2), P(x4, y4), P(x6, y6), P(x7, y7), P(x5, y5), P(x3, y3), P(x1, y1)], /// where the even indices are found first in ascending order and then the odd indices in descending order. /// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7]. -/// NOTE: This can be optimized by performing in-place value swapping (WIP). +/// TODO: This can be optimized by performing in-place value swapping (WIP). pub fn order_cfft_result_naive( input: &mut [FieldElement], ) -> Vec> { @@ -94,7 +94,7 @@ pub fn order_cfft_result_naive( /// [(x0, y0), (x2, y2), (x4, y4), (x6, y6), (x7, y7), (x5, y5), (x3, y3), (x1, y1)], /// where the even indices are found first in ascending order and then the odd indices in descending order. /// This function permutes the slice [0, 1, 2, 3, 4, 5, 6, 7] into [0, 2, 4, 6, 7, 5, 3, 1]. -/// NOTE: This can be optimized by performing in-place value swapping (WIP). +/// TODO: This can be optimized by performing in-place value swapping (WIP). pub fn order_icfft_input_naive( input: &mut [FieldElement], ) -> Vec> { diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 9b3cbf346..709ea76fc 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -3,6 +3,12 @@ use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; use alloc::vec::Vec; +/// Given g_n, a generator of the subgroup of the circle of size n, +/// and given a shift, that is a another point of the cirvle, +/// we define the coset shift + which is the set of all the points in +/// plus the shift. +/// For example, if = {p1, p2, p3, p4}, then g_8 + = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}. + #[derive(Debug, Clone)] pub struct Coset { // Coset: shift + where n = 2^{log_2_size}. From 8b2b3d28aac2c10661e0eec45313bdbe021e9062 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 24 Oct 2024 11:56:33 -0300 Subject: [PATCH 59/93] add alloc::vec --- math/src/circle/polynomial.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index a83e4e1be..0e6a149ca 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -13,6 +13,7 @@ use alloc::vec::Vec; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. /// Note that coeff has to be a vector with length a power of two 2^n. +#[cfg(feature = "alloc")] pub fn evaluate_cfft( mut coeff: Vec>, ) -> Vec> { @@ -33,6 +34,7 @@ pub fn evaluate_cfft( /// Interpolates the 2^n evaluations of a two-variables polynomial of degree 2^n - 1 on the points of the standard coset of size 2^n. /// As a result we obtain the coefficients of the polynomial in the basis: {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// Note that eval has to be a vector of length a power of two 2^n. +#[cfg(feature = "alloc")] pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { @@ -62,6 +64,7 @@ mod tests { use super::*; use crate::circle::cosets::Coset; type FE = FieldElement; + use alloc::vec; /// Naive evaluation of a polynomial of degree 3. fn evaluate_poly_4(coef: &[FE; 4], x: FE, y: FE) -> FE { From ebc5e9d8b294353a2bca0e6ebb71efef7b5479ba Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 24 Oct 2024 12:10:43 -0300 Subject: [PATCH 60/93] fix no-std --- math/src/circle/cosets.rs | 1 - math/src/circle/polynomial.rs | 3 +-- math/src/circle/twiddles.rs | 2 -- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 709ea76fc..1448ac913 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -1,7 +1,6 @@ extern crate alloc; use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; -use alloc::vec::Vec; /// Given g_n, a generator of the subgroup of the circle of size n, /// and given a shift, that is a another point of the cirvle, diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 0e6a149ca..a626bc2bd 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -2,13 +2,12 @@ use crate::{ fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; - +#[cfg(feature = "alloc")] use super::{ cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, twiddles::{get_twiddles, TwiddlesConfig}, }; -use alloc::vec::Vec; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 33693370e..4c7a8b115 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -3,8 +3,6 @@ use crate::{ circle::cosets::Coset, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; -use alloc::vec; -use alloc::vec::Vec; #[derive(PartialEq)] pub enum TwiddlesConfig { From f3028e5506f5fcce43b931324afc65325f3c22eb Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 24 Oct 2024 12:12:21 -0300 Subject: [PATCH 61/93] cargo fmt --- math/src/circle/polynomial.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index a626bc2bd..a6692216a 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,13 +1,13 @@ -use crate::{ - fft::cpu::bit_reversing::in_place_bit_reverse_permute, - field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, -}; #[cfg(feature = "alloc")] use super::{ cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, twiddles::{get_twiddles, TwiddlesConfig}, }; +use crate::{ + fft::cpu::bit_reversing::in_place_bit_reverse_permute, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. From b645a7b534673d0541d6b3cd3bcb8ad141c9dfe8 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 24 Oct 2024 12:19:37 -0300 Subject: [PATCH 62/93] fix no-std --- math/src/circle/cosets.rs | 1 + math/src/circle/polynomial.rs | 1 + math/src/circle/twiddles.rs | 1 + 3 files changed, 3 insertions(+) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 1448ac913..709ea76fc 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -1,6 +1,7 @@ extern crate alloc; use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; +use alloc::vec::Vec; /// Given g_n, a generator of the subgroup of the circle of size n, /// and given a shift, that is a another point of the cirvle, diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index a6692216a..e8ffa905d 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -8,6 +8,7 @@ use crate::{ fft::cpu::bit_reversing::in_place_bit_reverse_permute, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; +use alloc::vec::Vec; /// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 4c7a8b115..1d0351388 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -3,6 +3,7 @@ use crate::{ circle::cosets::Coset, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; +use alloc::vec::Vec; #[derive(PartialEq)] pub enum TwiddlesConfig { From c7cfd8ffb41350bf51314bf02559a18a92b541d2 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Thu, 24 Oct 2024 12:26:58 -0300 Subject: [PATCH 63/93] remove macro --- math/src/circle/polynomial.rs | 1 + math/src/circle/twiddles.rs | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index e8ffa905d..4f1f92527 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,3 +1,4 @@ +extern crate alloc; #[cfg(feature = "alloc")] use super::{ cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 1d0351388..68fae42fb 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -19,8 +19,8 @@ pub fn get_twiddles( let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); // The first set of twiddles are all the y coordinates of the half coset. - let mut twiddles: Vec>> = - vec![half_domain_points.iter().map(|p| p.y).collect()]; + let mut twiddles: Vec>> = Vec::new(); + twiddles.push(half_domain_points.iter().map(|p| p.y).collect()); if domain.log_2_size >= 2 { // The second set of twiddles are the x coordinates of the first half of the half coset. From fc68bea3c2079aa66572204c3e96569f84561edd Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Fri, 25 Oct 2024 10:22:33 -0300 Subject: [PATCH 64/93] add comment --- math/src/circle/point.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 13d97de3b..e4fc6000f 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -48,6 +48,8 @@ impl HasCircleParams for Mersenne31Field { impl HasCircleParams for Degree4ExtensionField { type FE = FieldElement; + // These parameters were taken from stwo's implementation: + // https://github.com/starkware-libs/stwo/blob/9cfd48af4e8ac5dd67643a92927c894066fa989c/crates/prover/src/core/circle.rs const CIRCLE_GENERATOR_X: Self::FE = Degree4ExtensionField::const_from_coefficients(1, 0, 478637715, 513582971); From 98e7dd26c1db1510c8e4cd18b20b70bcf05a747a Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 28 Oct 2024 11:16:02 -0300 Subject: [PATCH 65/93] init prover --- Cargo.toml | 2 +- math/src/field/fields/mersenne31/field.rs | 7 +++++++ provers/circle_stark/Cargo.toml | 16 ++++++++++++++++ provers/circle_stark/src/config.rs | 17 +++++++++++++++++ provers/circle_stark/src/lib.rs | 2 ++ provers/circle_stark/src/prover.rs | 22 ++++++++++++++++++++++ 6 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 provers/circle_stark/Cargo.toml create mode 100644 provers/circle_stark/src/config.rs create mode 100644 provers/circle_stark/src/lib.rs create mode 100644 provers/circle_stark/src/prover.rs diff --git a/Cargo.toml b/Cargo.toml index eed7f1be3..81a003780 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] -members = ["math", "crypto", "gpu", "benches", "provers/plonk", "provers/stark", "provers/groth16", "provers/groth16/arkworks-adapter", "provers/groth16/circom-adapter", "examples/merkle-tree-cli", "examples/prove-miden", "provers/winterfell_adapter", "examples/shamir_secret_sharing","examples/pinocchio", "examples/prove-verify-circom", "examples/baby-snark"] +members = ["math", "crypto", "gpu", "benches", "provers/plonk", "provers/stark", "provers/groth16", "provers/groth16/arkworks-adapter", "provers/groth16/circom-adapter", "examples/merkle-tree-cli", "examples/prove-miden", "provers/winterfell_adapter", "examples/shamir_secret_sharing","examples/pinocchio", "examples/prove-verify-circom", "examples/baby-snark", "provers/circle_stark"] exclude = ["ensure-no_std"] resolver = "2" diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 1c8b2dc58..f46482a02 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -1,3 +1,4 @@ +use crate::traits::{AsBytes, ByteConversion}; use crate::{ errors::CreationError, field::{ @@ -203,6 +204,12 @@ impl Display for FieldElement { } } +impl AsBytes for FieldElement { + fn as_bytes(&self) -> alloc::vec::Vec { + self.value().to_bytes_be() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/provers/circle_stark/Cargo.toml b/provers/circle_stark/Cargo.toml new file mode 100644 index 000000000..d66c2fd2d --- /dev/null +++ b/provers/circle_stark/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "circle_stark" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +lambdaworks-math = { workspace = true, features = [ + "std", + "lambdaworks-serde-binary", +] } +lambdaworks-crypto = { workspace = true, features = ["std", "serde"] } + +thiserror = "1.0.38" +itertools = "0.11.0" diff --git a/provers/circle_stark/src/config.rs b/provers/circle_stark/src/config.rs new file mode 100644 index 000000000..52df5fcef --- /dev/null +++ b/provers/circle_stark/src/config.rs @@ -0,0 +1,17 @@ +use lambdaworks_crypto::merkle_tree::{ + backends::types::{BatchKeccak256Backend, Keccak256Backend}, + merkle::MerkleTree, +}; + +// Merkle Trees configuration + +// Security of both hashes should match + +pub type FriMerkleTreeBackend = Keccak256Backend; +pub type FriMerkleTree = MerkleTree>; + +pub const COMMITMENT_SIZE: usize = 32; +pub type Commitment = [u8; COMMITMENT_SIZE]; + +pub type BatchedMerkleTreeBackend = BatchKeccak256Backend; +pub type BatchedMerkleTree = MerkleTree>; diff --git a/provers/circle_stark/src/lib.rs b/provers/circle_stark/src/lib.rs new file mode 100644 index 000000000..8dc915ef1 --- /dev/null +++ b/provers/circle_stark/src/lib.rs @@ -0,0 +1,2 @@ +pub mod config; +pub mod prover; diff --git a/provers/circle_stark/src/prover.rs b/provers/circle_stark/src/prover.rs new file mode 100644 index 000000000..fbc7239eb --- /dev/null +++ b/provers/circle_stark/src/prover.rs @@ -0,0 +1,22 @@ +use super::config::{BatchedMerkleTree, Commitment}; +use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; + +pub struct FRIProof; + +pub struct CommitmentData { + pub(crate) trace_polys: Vec>, + pub(crate) lde_trace_merkle_tree: BatchedMerkleTree, + pub(crate) lde_trace_merkle_root: Commitment, +} + +pub fn prove(trace: Vec>) -> CommitmentData { + let trace_polys: Vec>; + let lde_trace_merkle_tree; + let lde_trace_merkle_root; + + CommitmentData { + trace_polys, + lde_trace_merkle_tree, + lde_trace_merkle_root, + } +} From e9f60787bcb93c03e52b18efb48fc96e64915b25 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 28 Oct 2024 16:21:08 -0300 Subject: [PATCH 66/93] add point eval --- math/src/circle/polynomial.rs | 54 ++++++++++++++++++++ math/src/field/fields/mersenne31/field.rs | 2 +- provers/circle_stark/src/prover.rs | 62 +++++++++++++++++------ 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 4f1f92527..fc3aa4247 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,4 +1,5 @@ extern crate alloc; +use super::point::{CircleError, CirclePoint}; #[cfg(feature = "alloc")] use super::{ cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, @@ -60,6 +61,53 @@ pub fn interpolate_cfft( eval_ordered.iter().map(|coef| coef * factor).collect() } +/// Note: This implementation uses a straightforward approach and is intended for testing purposes only. +pub fn evaluate_point( + coef: &Vec>, + point: CirclePoint, +) -> FieldElement { + let order = coef.len(); + assert!( + order.is_power_of_two(), + "Coefficient length must be a power of 2" + ); + + let v_len = order.trailing_zeros() as usize; + + let mut a = point.x; + let mut v = Vec::with_capacity(v_len); + v.push(FieldElement::one()); + v.push(point.x); + for _ in 2..v_len { + a = a.square().double() - FieldElement::one(); + v.push(a); + } + + let mut result = FieldElement::zero(); + + for i in 0..order { + let mut term = coef[i]; + + if i % 2 == 1 { + term = term * point.y; + } + + let mut idx = i / 2; + let mut pos = 0; + while idx > 0 { + if idx % 2 == 1 { + term = term * v[pos + 1]; + } + idx /= 2; + pos += 1; + } + + result = result + term; + } + + result +} + #[cfg(test)] mod tests { use super::*; @@ -280,6 +328,12 @@ mod tests { FE::from(32), ]; let evals = evaluate_cfft(coeff.clone()); + + let coset = Coset::new_standard(5); + let coset_points = Coset::get_coset_points(&coset); + + assert_eq!(evals[4], evaluate_point(&coeff, coset_points[4].clone())); + let new_coeff = interpolate_cfft(evals); assert_eq!(coeff, new_coeff); diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index f46482a02..7654dc29b 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -206,7 +206,7 @@ impl Display for FieldElement { impl AsBytes for FieldElement { fn as_bytes(&self) -> alloc::vec::Vec { - self.value().to_bytes_be() + self.to_bytes_be() } } diff --git a/provers/circle_stark/src/prover.rs b/provers/circle_stark/src/prover.rs index fbc7239eb..4949ac1cf 100644 --- a/provers/circle_stark/src/prover.rs +++ b/provers/circle_stark/src/prover.rs @@ -1,22 +1,52 @@ -use super::config::{BatchedMerkleTree, Commitment}; -use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use crate::config::FriMerkleTree; -pub struct FRIProof; +use super::config::Commitment; +use lambdaworks_math::{ + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, + circle::polynomial::{interpolate_cfft, evaluate_cfft} +}; -pub struct CommitmentData { - pub(crate) trace_polys: Vec>, - pub(crate) lde_trace_merkle_tree: BatchedMerkleTree, - pub(crate) lde_trace_merkle_root: Commitment, +const BLOW_UP_FACTOR: usize = 2; + +pub fn prove(trace: Vec>) -> Commitment { + + let lde_domain_size = trace.len() * BLOW_UP_FACTOR; + + // Returns the coef of the interpolating polinomial of the trace on a natural domain. + let mut trace_poly = interpolate_cfft(trace); + + // Padding with zeros the coefficients of the polynomial, so we can evaluate it in the lde domain. + trace_poly.resize(lde_domain_size, FieldElement::zero()); + let lde_eval = evaluate_cfft(trace_poly); + + let tree = FriMerkleTree::::build(&lde_eval).unwrap(); + let commitment = tree.root; + + commitment } -pub fn prove(trace: Vec>) -> CommitmentData { - let trace_polys: Vec>; - let lde_trace_merkle_tree; - let lde_trace_merkle_root; - CommitmentData { - trace_polys, - lde_trace_merkle_tree, - lde_trace_merkle_root, +#[cfg(test)] +mod tests { + + use super::*; + + type FE = FieldElement; + + #[test] + fn basic_test() { + let trace = vec![ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + ]; + + let commitmet = prove(trace); + println!("{:?}", commitmet); } -} +} \ No newline at end of file From 0a50704caa9de88b2238fd5e2efdd2c1033176f6 Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 28 Oct 2024 18:14:18 -0300 Subject: [PATCH 67/93] addition between referenced and non-referenced values --- math/src/circle/point.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index e4fc6000f..e512eb9e3 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -78,6 +78,24 @@ impl> Add for &CirclePoint { CirclePoint { x, y } } } +impl> Add> for CirclePoint { + type Output = CirclePoint; + fn add(self, rhs: CirclePoint) -> Self::Output { + &self + &rhs + } +} +impl> Add> for &CirclePoint { + type Output = CirclePoint; + fn add(self, rhs: CirclePoint) -> Self::Output { + self + &rhs + } +} +impl> Add<&CirclePoint> for CirclePoint { + type Output = CirclePoint; + fn add(self, rhs: &CirclePoint) -> Self::Output { + &self + rhs + } +} /// Multiplication between a point and a scalar (i.e. group operation repeatedly): /// (x, y) * n = (x ,y) + ... + (x, y) n-times. From a045516f0f022d98f0b4d5f0dc014e154645ee9f Mon Sep 17 00:00:00 2001 From: Nicole Graus Date: Wed, 30 Oct 2024 11:15:13 -0300 Subject: [PATCH 68/93] Update math/src/circle/point.rs Co-authored-by: Ivan Litteri <67517699+ilitteri@users.noreply.github.com> --- math/src/circle/point.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index e512eb9e3..9c6e0662c 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -102,7 +102,8 @@ impl> Add<&CirclePoint> for CirclePoint { impl> Mul for CirclePoint { type Output = CirclePoint; - fn mul(self, mut scalar: u128) -> Self { + fn mul(self, scalar: u128) -> Self { + let mut scalar = scalar; let mut res = Self::zero(); let mut cur = self; loop { From c00a2236b949a225bc0ea33727864fd0991a5125 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 30 Oct 2024 11:25:02 -0300 Subject: [PATCH 69/93] explain why won't panic --- math/src/circle/polynomial.rs | 7 ++++++- math/src/circle/twiddles.rs | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 4f1f92527..3017ba184 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -39,6 +39,10 @@ pub fn evaluate_cfft( pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { + if eval.len() == 0 { + return vec![FieldElement::::zero()]; + } + // We get the twiddles for the interpolation. let domain_log_2_size: u32 = eval.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); @@ -53,7 +57,8 @@ pub fn interpolate_cfft( in_place_bit_reverse_permute::>(&mut eval_ordered); // The icfft returns all the coefficients multiplied by 2^n, the length of the evaluations. - // So we multiply every element that outputs the icfft byt the inverse of 2^n to get the actual coefficients. + // So we multiply every element that outputs the icfft by the inverse of 2^n to get the actual coefficients. + // Note that this `unwrap` will never panic because eval.len() != 0. let factor = (FieldElement::::from(eval.len() as u64)) .inv() .unwrap(); diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index 68fae42fb..6a07804c4 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -45,6 +45,8 @@ pub fn get_twiddles( if config == TwiddlesConfig::Interpolation { // For the interpolation, we need to take the inverse element of each twiddle in the default order. + // We can take inverse being sure that the `unwrap` won't panic because the twiddles are coordinates + // of elements of the coset (or their squares) so they can't be zero. twiddles.iter_mut().for_each(|x| { FieldElement::::inplace_batch_inverse(x).unwrap(); }); From e0fa3904aeb98a3146f4a5f73f05c4f801f76287 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 30 Oct 2024 11:32:27 -0300 Subject: [PATCH 70/93] add errors.rs --- math/src/circle/errors.rs | 4 ++ math/src/circle/mod.rs | 1 + math/src/circle/point.rs | 129 ++++++++++++++++++-------------------- 3 files changed, 67 insertions(+), 67 deletions(-) create mode 100644 math/src/circle/errors.rs diff --git a/math/src/circle/errors.rs b/math/src/circle/errors.rs new file mode 100644 index 000000000..51dcb720b --- /dev/null +++ b/math/src/circle/errors.rs @@ -0,0 +1,4 @@ +#[derive(Debug)] +pub enum CircleError { + PointDoesntSatisfyCircleEquation, +} diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs index f5a65721f..ac576194f 100644 --- a/math/src/circle/mod.rs +++ b/math/src/circle/mod.rs @@ -1,5 +1,6 @@ pub mod cfft; pub mod cosets; +pub mod errors; pub mod point; pub mod polynomial; pub mod twiddles; diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 9c6e0662c..24759a043 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -1,3 +1,4 @@ +use super::errors::CircleError; use crate::field::traits::IsField; use crate::field::{ element::FieldElement, @@ -9,16 +10,73 @@ use core::ops::{Add, Mul}; /// x in F, y in F and x^2 + y^2 = 1, i.e. the Circle. The operation of the group will have /// additive notation and is as follows: /// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) - #[derive(Debug, Clone)] pub struct CirclePoint { pub x: FieldElement, pub y: FieldElement, } -#[derive(Debug)] -pub enum CircleError { - PointDoesntSatisfyCircleEquation, +impl> CirclePoint { + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(Self { x, y }) + } else { + Err(CircleError::PointDoesntSatisfyCircleEquation) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } + + /// Computes 2(x, y) = (2x^2 - 1, 2xy). + pub fn double(self) -> Self { + Self::new( + self.x.square().double() - FieldElement::one(), + self.x.double() * self.y, + ) + .unwrap() + } + + /// Computes 2^n * (x, y). + pub fn repeated_double(self, n: u32) -> Self { + let mut res = self; + for _ in 0..n { + res = res.double(); + } + res + } + + /// Computes the inverse of the point. + /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation + /// because the norm of every point in the circle is one. + pub fn conjugate(self) -> Self { + Self { + x: self.x, + y: -self.y, + } + } + + pub fn antipode(self) -> Self { + Self { + x: -self.x, + y: -self.y, + } + } + + pub const GENERATOR: Self = Self { + x: F::CIRCLE_GENERATOR_X, + y: F::CIRCLE_GENERATOR_Y, + }; + + /// Returns the generator of the subgroup of order n = 2^log_2_size. + /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. + pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { + Self::GENERATOR.repeated_double(31 - log_2_size) + } + + pub const ORDER: u128 = F::ORDER; } /// Parameters of the base field that we'll need to define its Circle. @@ -119,69 +177,6 @@ impl> Mul for CirclePoint { } } -impl> CirclePoint { - pub fn new(x: FieldElement, y: FieldElement) -> Result { - if x.square() + y.square() == FieldElement::one() { - Ok(Self { x, y }) - } else { - Err(CircleError::PointDoesntSatisfyCircleEquation) - } - } - - /// Neutral element of the Circle group (with additive notation). - pub fn zero() -> Self { - Self::new(FieldElement::one(), FieldElement::zero()).unwrap() - } - - /// Computes 2(x, y) = (2x^2 - 1, 2xy). - pub fn double(self) -> Self { - Self::new( - self.x.square().double() - FieldElement::one(), - self.x.double() * self.y, - ) - .unwrap() - } - - /// Computes 2^n * (x, y). - pub fn repeated_double(self, n: u32) -> Self { - let mut res = self; - for _ in 0..n { - res = res.double(); - } - res - } - - /// Computes the inverse of the point. - /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation - /// because the norm of every point in the circle is one. - pub fn conjugate(self) -> Self { - Self { - x: self.x, - y: -self.y, - } - } - - pub fn antipode(self) -> Self { - Self { - x: -self.x, - y: -self.y, - } - } - - pub const GENERATOR: Self = Self { - x: F::CIRCLE_GENERATOR_X, - y: F::CIRCLE_GENERATOR_Y, - }; - - /// Returns the generator of the subgroup of order n = 2^log_2_size. - /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. - pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { - Self::GENERATOR.repeated_double(31 - log_2_size) - } - - pub const ORDER: u128 = F::ORDER; -} - #[cfg(test)] mod tests { use super::*; From 3ed8ac1f893ab655532fd2ce7e4c29a53dca0318 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 30 Oct 2024 11:47:53 -0300 Subject: [PATCH 71/93] fix vec --- math/src/circle/polynomial.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 3017ba184..d9764c475 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -40,7 +40,9 @@ pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { if eval.len() == 0 { - return vec![FieldElement::::zero()]; + let mut poly = Vec::new(); + poly.push(FieldElement::::zero()); + return poly; } // We get the twiddles for the interpolation. From a7161f252ea6796518eb4da6afd438c835bd05fa Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 30 Oct 2024 13:15:17 -0300 Subject: [PATCH 72/93] Evaluate and interpolate functions have non-mutable inputs --- math/src/circle/cfft.rs | 2 +- math/src/circle/polynomial.rs | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 3b327a954..5e360b1bf 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -78,7 +78,7 @@ pub fn icfft( /// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7]. /// TODO: This can be optimized by performing in-place value swapping (WIP). pub fn order_cfft_result_naive( - input: &mut [FieldElement], + input: &[FieldElement], ) -> Vec> { let mut result = Vec::new(); let length = input.len(); diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index d9764c475..a3ee2fadc 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -16,8 +16,10 @@ use alloc::vec::Vec; /// Note that coeff has to be a vector with length a power of two 2^n. #[cfg(feature = "alloc")] pub fn evaluate_cfft( - mut coeff: Vec>, + coeff: Vec>, ) -> Vec> { + let mut coeff = coeff; + // We get the twiddles for the Evaluation. let domain_log_2_size: u32 = coeff.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); @@ -29,19 +31,21 @@ pub fn evaluate_cfft( cfft(&mut coeff, twiddles); // The cfft returns the evaluations in a certain order, so we permute them to get the natural order. - order_cfft_result_naive(&mut coeff) + order_cfft_result_naive(&coeff) } /// Interpolates the 2^n evaluations of a two-variables polynomial of degree 2^n - 1 on the points of the standard coset of size 2^n. /// As a result we obtain the coefficients of the polynomial in the basis: {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} /// Note that eval has to be a vector of length a power of two 2^n. +/// If the vector of evaluations is empty, it returns an empty vector. #[cfg(feature = "alloc")] pub fn interpolate_cfft( - mut eval: Vec>, + eval: Vec>, ) -> Vec> { - if eval.len() == 0 { - let mut poly = Vec::new(); - poly.push(FieldElement::::zero()); + let mut eval = eval; + + if eval.is_empty() { + let poly: Vec> = Vec::new(); return poly; } @@ -135,8 +139,9 @@ mod tests { expected_result.push(point_eval); } + let input_vec = input.to_vec(); // We evaluate the polynomial using now the cfft. - let result = evaluate_cfft(input.to_vec()); + let result = evaluate_cfft(input_vec); let slice_result: &[FE] = &result; assert_eq!(slice_result, expected_result); From b88794f3cbc55c2a779b3fbfa510f6d7bf7cb5a3 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 30 Oct 2024 13:29:37 -0300 Subject: [PATCH 73/93] fix clippy --- math/src/circle/cfft.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 5e360b1bf..f060e02d6 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -121,9 +121,9 @@ mod tests { fn ordering_cfft_result_works_for_4_points() { let expected_slice = [FE::from(0), FE::from(1), FE::from(2), FE::from(3)]; - let mut slice = [FE::from(0), FE::from(2), FE::from(3), FE::from(1)]; + let slice = [FE::from(0), FE::from(2), FE::from(3), FE::from(1)]; - let res = order_cfft_result_naive(&mut slice); + let res = order_cfft_result_naive(&slice); assert_eq!(res, expected_slice) } @@ -149,7 +149,7 @@ mod tests { FE::from(15), ]; - let mut slice = [ + let slice = [ FE::from(0), FE::from(2), FE::from(4), @@ -168,7 +168,7 @@ mod tests { FE::from(1), ]; - let res = order_cfft_result_naive(&mut slice); + let res = order_cfft_result_naive(&slice); assert_eq!(res, expected_slice) } From b2e9b9d0948b6c92ee153247bcf66324c136aeb7 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 30 Oct 2024 18:37:49 -0300 Subject: [PATCH 74/93] MulAssign for points and double function takes a reference --- math/src/circle/cosets.rs | 4 +- math/src/circle/point.rs | 177 +++++++++++++++++++++++--------------- 2 files changed, 109 insertions(+), 72 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 709ea76fc..957097efb 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -3,8 +3,8 @@ use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; use alloc::vec::Vec; -/// Given g_n, a generator of the subgroup of the circle of size n, -/// and given a shift, that is a another point of the cirvle, +/// Given g_n, a generator of the subgroup of size n of the circle, i.e. , +/// and given a shift, that is a another point of the circle, /// we define the coset shift + which is the set of all the points in /// plus the shift. /// For example, if = {p1, p2, p3, p4}, then g_8 + = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}. diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 24759a043..6676d0e5f 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -1,82 +1,24 @@ -use super::errors::CircleError; use crate::field::traits::IsField; use crate::field::{ element::FieldElement, fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, }; -use core::ops::{Add, Mul}; +use core::ops::{Add, AddAssign, Mul, MulAssign}; /// Given a Field F, we implement here the Group which consists of all the points (x, y) such as /// x in F, y in F and x^2 + y^2 = 1, i.e. the Circle. The operation of the group will have /// additive notation and is as follows: -/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c). + #[derive(Debug, Clone)] pub struct CirclePoint { pub x: FieldElement, pub y: FieldElement, } -impl> CirclePoint { - pub fn new(x: FieldElement, y: FieldElement) -> Result { - if x.square() + y.square() == FieldElement::one() { - Ok(Self { x, y }) - } else { - Err(CircleError::PointDoesntSatisfyCircleEquation) - } - } - - /// Neutral element of the Circle group (with additive notation). - pub fn zero() -> Self { - Self::new(FieldElement::one(), FieldElement::zero()).unwrap() - } - - /// Computes 2(x, y) = (2x^2 - 1, 2xy). - pub fn double(self) -> Self { - Self::new( - self.x.square().double() - FieldElement::one(), - self.x.double() * self.y, - ) - .unwrap() - } - - /// Computes 2^n * (x, y). - pub fn repeated_double(self, n: u32) -> Self { - let mut res = self; - for _ in 0..n { - res = res.double(); - } - res - } - - /// Computes the inverse of the point. - /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation - /// because the norm of every point in the circle is one. - pub fn conjugate(self) -> Self { - Self { - x: self.x, - y: -self.y, - } - } - - pub fn antipode(self) -> Self { - Self { - x: -self.x, - y: -self.y, - } - } - - pub const GENERATOR: Self = Self { - x: F::CIRCLE_GENERATOR_X, - y: F::CIRCLE_GENERATOR_Y, - }; - - /// Returns the generator of the subgroup of order n = 2^log_2_size. - /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. - pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { - Self::GENERATOR.repeated_double(31 - log_2_size) - } - - pub const ORDER: u128 = F::ORDER; +#[derive(Debug)] +pub enum CircleError { + PointDoesntSatisfyCircleEquation, } /// Parameters of the base field that we'll need to define its Circle. @@ -136,7 +78,7 @@ impl> Add for &CirclePoint { CirclePoint { x, y } } } -impl> Add> for CirclePoint { +impl> Add for CirclePoint { type Output = CirclePoint; fn add(self, rhs: CirclePoint) -> Self::Output { &self + &rhs @@ -154,28 +96,123 @@ impl> Add<&CirclePoint> for CirclePoint { &self + rhs } } +impl> AddAssign<&CirclePoint> for CirclePoint { + fn add_assign(&mut self, rhs: &CirclePoint) { + *self = &*self + rhs; + } +} +impl> AddAssign> for CirclePoint { + fn add_assign(&mut self, rhs: CirclePoint) { + *self += &rhs; + } +} /// Multiplication between a point and a scalar (i.e. group operation repeatedly): /// (x, y) * n = (x ,y) + ... + (x, y) n-times. -impl> Mul for CirclePoint { +impl> Mul for &CirclePoint { type Output = CirclePoint; - fn mul(self, scalar: u128) -> Self { + fn mul(self, scalar: u128) -> Self::Output { let mut scalar = scalar; - let mut res = Self::zero(); - let mut cur = self; + let mut res = CirclePoint::::zero(); + let mut cur = self.clone(); loop { if scalar == 0 { return res; } if scalar & 1 == 1 { - res = &res + &cur; + res += &cur; } cur = cur.double(); scalar >>= 1; } } } +impl> Mul for CirclePoint { + type Output = CirclePoint; + fn mul(self, scalar: u128) -> Self::Output { + &self * scalar + } +} +impl> MulAssign for CirclePoint { + fn mul_assign(&mut self, scalar: u128) { + let mut scalar = scalar; + let mut res = CirclePoint::::zero(); + loop { + if scalar == 0 { + *self = res.clone(); + } + if scalar & 1 == 1 { + res += &*self; + } + *self = self.double(); + scalar >>= 1; + } + } +} + +impl> CirclePoint { + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(Self { x, y }) + } else { + Err(CircleError::PointDoesntSatisfyCircleEquation) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } + + /// Computes 2(x, y) = (2x^2 - 1, 2xy). + pub fn double(&self) -> Self { + Self::new( + self.x.square().double() - FieldElement::one(), + self.x.double() * self.y.clone(), + ) + .unwrap() + } + + /// Computes 2^n * (x, y). + pub fn repeated_double(self, n: u32) -> Self { + let mut res = self; + for _ in 0..n { + res = res.double(); + } + res + } + + /// Computes the inverse of the point. + /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation + /// because the norm of every point in the circle is one. + pub fn conjugate(self) -> Self { + Self { + x: self.x, + y: -self.y, + } + } + + pub fn antipode(self) -> Self { + Self { + x: -self.x, + y: -self.y, + } + } + + pub const GENERATOR: Self = Self { + x: F::CIRCLE_GENERATOR_X, + y: F::CIRCLE_GENERATOR_Y, + }; + + /// Returns the generator of the subgroup of order n = 2^log_2_size. + /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. + pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { + Self::GENERATOR.repeated_double(31 - log_2_size) + } + + pub const ORDER: u128 = F::ORDER; +} #[cfg(test)] mod tests { From 26d60788696944ab121dabd1a6d60c287140606c Mon Sep 17 00:00:00 2001 From: Nicole Date: Thu, 31 Oct 2024 10:29:56 -0300 Subject: [PATCH 75/93] Revert "MulAssign for points and double function takes a reference" This reverts commit b2e9b9d0948b6c92ee153247bcf66324c136aeb7. --- math/src/circle/cosets.rs | 4 +- math/src/circle/point.rs | 177 +++++++++++++++----------------------- 2 files changed, 72 insertions(+), 109 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 957097efb..709ea76fc 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -3,8 +3,8 @@ use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; use alloc::vec::Vec; -/// Given g_n, a generator of the subgroup of size n of the circle, i.e. , -/// and given a shift, that is a another point of the circle, +/// Given g_n, a generator of the subgroup of the circle of size n, +/// and given a shift, that is a another point of the cirvle, /// we define the coset shift + which is the set of all the points in /// plus the shift. /// For example, if = {p1, p2, p3, p4}, then g_8 + = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}. diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 6676d0e5f..24759a043 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -1,24 +1,82 @@ +use super::errors::CircleError; use crate::field::traits::IsField; use crate::field::{ element::FieldElement, fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, }; -use core::ops::{Add, AddAssign, Mul, MulAssign}; +use core::ops::{Add, Mul}; /// Given a Field F, we implement here the Group which consists of all the points (x, y) such as /// x in F, y in F and x^2 + y^2 = 1, i.e. the Circle. The operation of the group will have /// additive notation and is as follows: -/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c). - +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) #[derive(Debug, Clone)] pub struct CirclePoint { pub x: FieldElement, pub y: FieldElement, } -#[derive(Debug)] -pub enum CircleError { - PointDoesntSatisfyCircleEquation, +impl> CirclePoint { + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(Self { x, y }) + } else { + Err(CircleError::PointDoesntSatisfyCircleEquation) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } + + /// Computes 2(x, y) = (2x^2 - 1, 2xy). + pub fn double(self) -> Self { + Self::new( + self.x.square().double() - FieldElement::one(), + self.x.double() * self.y, + ) + .unwrap() + } + + /// Computes 2^n * (x, y). + pub fn repeated_double(self, n: u32) -> Self { + let mut res = self; + for _ in 0..n { + res = res.double(); + } + res + } + + /// Computes the inverse of the point. + /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation + /// because the norm of every point in the circle is one. + pub fn conjugate(self) -> Self { + Self { + x: self.x, + y: -self.y, + } + } + + pub fn antipode(self) -> Self { + Self { + x: -self.x, + y: -self.y, + } + } + + pub const GENERATOR: Self = Self { + x: F::CIRCLE_GENERATOR_X, + y: F::CIRCLE_GENERATOR_Y, + }; + + /// Returns the generator of the subgroup of order n = 2^log_2_size. + /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. + pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { + Self::GENERATOR.repeated_double(31 - log_2_size) + } + + pub const ORDER: u128 = F::ORDER; } /// Parameters of the base field that we'll need to define its Circle. @@ -78,7 +136,7 @@ impl> Add for &CirclePoint { CirclePoint { x, y } } } -impl> Add for CirclePoint { +impl> Add> for CirclePoint { type Output = CirclePoint; fn add(self, rhs: CirclePoint) -> Self::Output { &self + &rhs @@ -96,123 +154,28 @@ impl> Add<&CirclePoint> for CirclePoint { &self + rhs } } -impl> AddAssign<&CirclePoint> for CirclePoint { - fn add_assign(&mut self, rhs: &CirclePoint) { - *self = &*self + rhs; - } -} -impl> AddAssign> for CirclePoint { - fn add_assign(&mut self, rhs: CirclePoint) { - *self += &rhs; - } -} /// Multiplication between a point and a scalar (i.e. group operation repeatedly): /// (x, y) * n = (x ,y) + ... + (x, y) n-times. -impl> Mul for &CirclePoint { +impl> Mul for CirclePoint { type Output = CirclePoint; - fn mul(self, scalar: u128) -> Self::Output { + fn mul(self, scalar: u128) -> Self { let mut scalar = scalar; - let mut res = CirclePoint::::zero(); - let mut cur = self.clone(); + let mut res = Self::zero(); + let mut cur = self; loop { if scalar == 0 { return res; } if scalar & 1 == 1 { - res += &cur; + res = &res + &cur; } cur = cur.double(); scalar >>= 1; } } } -impl> Mul for CirclePoint { - type Output = CirclePoint; - fn mul(self, scalar: u128) -> Self::Output { - &self * scalar - } -} -impl> MulAssign for CirclePoint { - fn mul_assign(&mut self, scalar: u128) { - let mut scalar = scalar; - let mut res = CirclePoint::::zero(); - loop { - if scalar == 0 { - *self = res.clone(); - } - if scalar & 1 == 1 { - res += &*self; - } - *self = self.double(); - scalar >>= 1; - } - } -} - -impl> CirclePoint { - pub fn new(x: FieldElement, y: FieldElement) -> Result { - if x.square() + y.square() == FieldElement::one() { - Ok(Self { x, y }) - } else { - Err(CircleError::PointDoesntSatisfyCircleEquation) - } - } - - /// Neutral element of the Circle group (with additive notation). - pub fn zero() -> Self { - Self::new(FieldElement::one(), FieldElement::zero()).unwrap() - } - - /// Computes 2(x, y) = (2x^2 - 1, 2xy). - pub fn double(&self) -> Self { - Self::new( - self.x.square().double() - FieldElement::one(), - self.x.double() * self.y.clone(), - ) - .unwrap() - } - - /// Computes 2^n * (x, y). - pub fn repeated_double(self, n: u32) -> Self { - let mut res = self; - for _ in 0..n { - res = res.double(); - } - res - } - - /// Computes the inverse of the point. - /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation - /// because the norm of every point in the circle is one. - pub fn conjugate(self) -> Self { - Self { - x: self.x, - y: -self.y, - } - } - - pub fn antipode(self) -> Self { - Self { - x: -self.x, - y: -self.y, - } - } - - pub const GENERATOR: Self = Self { - x: F::CIRCLE_GENERATOR_X, - y: F::CIRCLE_GENERATOR_Y, - }; - - /// Returns the generator of the subgroup of order n = 2^log_2_size. - /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. - pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { - Self::GENERATOR.repeated_double(31 - log_2_size) - } - - pub const ORDER: u128 = F::ORDER; -} #[cfg(test)] mod tests { From a6738c7af8a693ceb7ef8989544699a5c79edc9a Mon Sep 17 00:00:00 2001 From: Nicole Date: Thu, 31 Oct 2024 10:46:18 -0300 Subject: [PATCH 76/93] MulAssign and AddAssign --- math/src/circle/cosets.rs | 4 +-- math/src/circle/point.rs | 53 ++++++++++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 709ea76fc..957097efb 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -3,8 +3,8 @@ use crate::circle::point::CirclePoint; use crate::field::fields::mersenne31::field::Mersenne31Field; use alloc::vec::Vec; -/// Given g_n, a generator of the subgroup of the circle of size n, -/// and given a shift, that is a another point of the cirvle, +/// Given g_n, a generator of the subgroup of size n of the circle, i.e. , +/// and given a shift, that is a another point of the circle, /// we define the coset shift + which is the set of all the points in /// plus the shift. /// For example, if = {p1, p2, p3, p4}, then g_8 + = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}. diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 24759a043..e0e7aa210 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -4,7 +4,7 @@ use crate::field::{ element::FieldElement, fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, }; -use core::ops::{Add, Mul}; +use core::ops::{Add, AddAssign, Mul, MulAssign}; /// Given a Field F, we implement here the Group which consists of all the points (x, y) such as /// x in F, y in F and x^2 + y^2 = 1, i.e. the Circle. The operation of the group will have @@ -31,10 +31,10 @@ impl> CirclePoint { } /// Computes 2(x, y) = (2x^2 - 1, 2xy). - pub fn double(self) -> Self { + pub fn double(&self) -> Self { Self::new( self.x.square().double() - FieldElement::one(), - self.x.double() * self.y, + self.x.double() * self.y.clone(), ) .unwrap() } @@ -129,14 +129,13 @@ impl> PartialEq for CirclePoint { /// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) impl> Add for &CirclePoint { type Output = CirclePoint; - fn add(self, other: Self) -> Self::Output { let x = &self.x * &other.x - &self.y * &other.y; let y = &self.x * &other.y + &self.y * &other.x; CirclePoint { x, y } } } -impl> Add> for CirclePoint { +impl> Add for CirclePoint { type Output = CirclePoint; fn add(self, rhs: CirclePoint) -> Self::Output { &self + &rhs @@ -154,28 +153,58 @@ impl> Add<&CirclePoint> for CirclePoint { &self + rhs } } - +impl> AddAssign<&CirclePoint> for CirclePoint { + fn add_assign(&mut self, rhs: &CirclePoint) { + *self = &*self + rhs; + } +} +impl> AddAssign> for CirclePoint { + fn add_assign(&mut self, rhs: CirclePoint) { + *self += &rhs; + } +} /// Multiplication between a point and a scalar (i.e. group operation repeatedly): /// (x, y) * n = (x ,y) + ... + (x, y) n-times. -impl> Mul for CirclePoint { +impl> Mul for &CirclePoint { type Output = CirclePoint; - - fn mul(self, scalar: u128) -> Self { + fn mul(self, scalar: u128) -> Self::Output { let mut scalar = scalar; - let mut res = Self::zero(); - let mut cur = self; + let mut res = CirclePoint::::zero(); + let mut cur = self.clone(); loop { if scalar == 0 { return res; } if scalar & 1 == 1 { - res = &res + &cur; + res += &cur; } cur = cur.double(); scalar >>= 1; } } } +impl> Mul for CirclePoint { + type Output = CirclePoint; + fn mul(self, scalar: u128) -> Self::Output { + &self * scalar + } +} +impl> MulAssign for CirclePoint { + fn mul_assign(&mut self, scalar: u128) { + let mut scalar = scalar; + let mut res = CirclePoint::::zero(); + loop { + if scalar == 0 { + *self = res.clone(); + } + if scalar & 1 == 1 { + res += &*self; + } + *self = self.double(); + scalar >>= 1; + } + } +} #[cfg(test)] mod tests { From d66d52d453305428968a8e046aa8e8ddb9a02282 Mon Sep 17 00:00:00 2001 From: Nicole Date: Thu, 31 Oct 2024 16:28:41 -0300 Subject: [PATCH 77/93] add vanishing polynomial of a coset --- provers/circle_stark/src/constraints.rs | 42 +++++++++++++++++++++++++ provers/circle_stark/src/lib.rs | 1 + 2 files changed, 43 insertions(+) create mode 100644 provers/circle_stark/src/constraints.rs diff --git a/provers/circle_stark/src/constraints.rs b/provers/circle_stark/src/constraints.rs new file mode 100644 index 000000000..76a7d7b25 --- /dev/null +++ b/provers/circle_stark/src/constraints.rs @@ -0,0 +1,42 @@ +use lambdaworks_math::{ + circle::point::CirclePoint, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; + +pub fn evaluate_vanishing_poly( + log_2_size: u32, + point: CirclePoint, +) -> FieldElement { + let mut x = point.x; + for _ in 1..log_2_size { + x = x.square().double() - FieldElement::one(); + } + x +} + +#[cfg(test)] +mod tests { + use super::*; + use lambdaworks_math::circle::cosets::Coset; + + type FE = FieldElement; + + #[test] + fn vanishing_poly_vanishes_in_coset() { + let log_2_size = 3; + let coset = Coset::new_standard(log_2_size); + let points = Coset::get_coset_points(&coset); + for point in points { + assert_eq!(evaluate_vanishing_poly(log_2_size, point), FE::zero()); + } + } + #[test] + fn vanishing_poly_doesnt_vanishe_outside_coset() { + let log_2_size = 3; + let coset = Coset::new_standard(log_2_size + 1); + let points = Coset::get_coset_points(&coset); + for point in points { + assert_ne!(evaluate_vanishing_poly(log_2_size, point), FE::zero()); + } + } +} diff --git a/provers/circle_stark/src/lib.rs b/provers/circle_stark/src/lib.rs index 8dc915ef1..aee7cb008 100644 --- a/provers/circle_stark/src/lib.rs +++ b/provers/circle_stark/src/lib.rs @@ -1,2 +1,3 @@ pub mod config; +pub mod constraints; pub mod prover; From 3b165ee8486ec8ca55818b1a16f0fa06e9649715 Mon Sep 17 00:00:00 2001 From: Nicole Date: Fri, 1 Nov 2024 12:05:36 -0300 Subject: [PATCH 78/93] add comments for vanishing poly and zerofier --- math/src/circle/polynomial.rs | 2 +- provers/circle_stark/src/constraints.rs | 37 ++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 19ffbade9..77f0fd76c 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,5 +1,5 @@ extern crate alloc; -use super::point::{CircleError, CirclePoint}; +use super::point::CirclePoint; #[cfg(feature = "alloc")] use super::{ cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, diff --git a/provers/circle_stark/src/constraints.rs b/provers/circle_stark/src/constraints.rs index 76a7d7b25..302f8d003 100644 --- a/provers/circle_stark/src/constraints.rs +++ b/provers/circle_stark/src/constraints.rs @@ -3,9 +3,14 @@ use lambdaworks_math::{ field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; +/// Evaluate the vanishing polynomial of the standard coset of size 2^log_2_size in a point. +/// The vanishing polynomial of a coset is the polynomial that takes the value zero when evaluated +/// in all the points of the coset. +/// We are using that if we take a point in g_{2n} + and double it n-1 times, then +/// we'll get the point (0, 1) or (0, -1); so its coordinate x is always 0. pub fn evaluate_vanishing_poly( log_2_size: u32, - point: CirclePoint, + point: &CirclePoint, ) -> FieldElement { let mut x = point.x; for _ in 1..log_2_size { @@ -14,6 +19,16 @@ pub fn evaluate_vanishing_poly( x } +// Evaluate the polynomial that vanishes at a specific point in the domain at an arbitrary point. +// This use the "tangent" line to the domain in the vanish point. +// Check: https://vitalik.eth.limo/general/2024/07/23/circlestarks.html for details. +pub fn evaluate_single_point_zerofier( + vanish_point: CirclePoint, + eval_point: &CirclePoint, +) -> FieldElement { + (eval_point + vanish_point.conjugate()).x - FieldElement::::one() +} + #[cfg(test)] mod tests { use super::*; @@ -27,16 +42,30 @@ mod tests { let coset = Coset::new_standard(log_2_size); let points = Coset::get_coset_points(&coset); for point in points { - assert_eq!(evaluate_vanishing_poly(log_2_size, point), FE::zero()); + assert_eq!(evaluate_vanishing_poly(log_2_size, &point), FE::zero()); } } #[test] - fn vanishing_poly_doesnt_vanishe_outside_coset() { + fn vanishing_poly_doesnt_vanishes_outside_coset() { let log_2_size = 3; let coset = Coset::new_standard(log_2_size + 1); let points = Coset::get_coset_points(&coset); for point in points { - assert_ne!(evaluate_vanishing_poly(log_2_size, point), FE::zero()); + assert_ne!(evaluate_vanishing_poly(log_2_size, &point), FE::zero()); } } + + #[test] + fn single_point_zerofier_vanishes_only_in_vanish_point() { + let vanish_point = CirclePoint::GENERATOR; + let eval_point = &vanish_point * 3; + assert_eq!( + evaluate_single_point_zerofier(vanish_point.clone(), &vanish_point), + FE::zero() + ); + assert_ne!( + evaluate_single_point_zerofier(vanish_point.clone(), &eval_point), + FE::zero() + ); + } } From ca2903041fa71ae39b11c8f8b8d25c4c5f8e9744 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 4 Nov 2024 15:13:52 -0300 Subject: [PATCH 79/93] wip --- provers/circle_stark/Cargo.toml | 1 + .../circle_stark/examples/simple_fibonacci.rs | 176 ++++++++++++ provers/circle_stark/src/air.rs | 102 +++++++ provers/circle_stark/src/air_context.rs | 31 ++ .../circle_stark/src/constraints/boundary.rs | 200 +++++++++++++ .../circle_stark/src/constraints/evaluator.rs | 221 ++++++++++++++ provers/circle_stark/src/constraints/mod.rs | 3 + .../src/constraints/transition.rs | 239 ++++++++++++++++ provers/circle_stark/src/frame.rs | 42 +++ provers/circle_stark/src/lib.rs | 8 +- provers/circle_stark/src/table.rs | 112 ++++++++ provers/circle_stark/src/trace.rs | 269 ++++++++++++++++++ .../src/{constraints.rs => vanishing_poly.rs} | 0 13 files changed, 1403 insertions(+), 1 deletion(-) create mode 100644 provers/circle_stark/examples/simple_fibonacci.rs create mode 100644 provers/circle_stark/src/air.rs create mode 100644 provers/circle_stark/src/air_context.rs create mode 100644 provers/circle_stark/src/constraints/boundary.rs create mode 100644 provers/circle_stark/src/constraints/evaluator.rs create mode 100644 provers/circle_stark/src/constraints/mod.rs create mode 100644 provers/circle_stark/src/constraints/transition.rs create mode 100644 provers/circle_stark/src/frame.rs create mode 100644 provers/circle_stark/src/table.rs create mode 100644 provers/circle_stark/src/trace.rs rename provers/circle_stark/src/{constraints.rs => vanishing_poly.rs} (100%) diff --git a/provers/circle_stark/Cargo.toml b/provers/circle_stark/Cargo.toml index d66c2fd2d..f167233e8 100644 --- a/provers/circle_stark/Cargo.toml +++ b/provers/circle_stark/Cargo.toml @@ -14,3 +14,4 @@ lambdaworks-crypto = { workspace = true, features = ["std", "serde"] } thiserror = "1.0.38" itertools = "0.11.0" +serde = { version = "1.0", features = ["derive"] } diff --git a/provers/circle_stark/examples/simple_fibonacci.rs b/provers/circle_stark/examples/simple_fibonacci.rs new file mode 100644 index 000000000..204aa938c --- /dev/null +++ b/provers/circle_stark/examples/simple_fibonacci.rs @@ -0,0 +1,176 @@ +use crate::{ + constraints::{ + boundary::{BoundaryConstraint, BoundaryConstraints}, + transition::TransitionConstraint, + }, + context::AirContext, + frame::Frame, + proof::options::ProofOptions, + trace::TraceTable, + traits::AIR, +}; +use lambdaworks_math::field::{element::FieldElement, traits::IsFFTField}; +use std::marker::PhantomData; + +#[derive(Clone)] +struct FibConstraint { + phantom: PhantomData, +} + +impl FibConstraint { + pub fn new() -> Self { + Self { + phantom: PhantomData, + } + } +} + +impl TransitionConstraint for FibConstraint +where + F: IsFFTField + Send + Sync, +{ + fn degree(&self) -> usize { + 1 + } + + fn constraint_idx(&self) -> usize { + 0 + } + + fn end_exemptions(&self) -> usize { + 2 + } + + fn evaluate( + &self, + frame: &Frame, + transition_evaluations: &mut [FieldElement], + _periodic_values: &[FieldElement], + _rap_challenges: &[FieldElement], + ) { + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + let third_step = frame.get_evaluation_step(2); + + let a0 = first_step.get_main_evaluation_element(0, 0); + let a1 = second_step.get_main_evaluation_element(0, 0); + let a2 = third_step.get_main_evaluation_element(0, 0); + + let res = a2 - a1 - a0; + + transition_evaluations[self.constraint_idx()] = res; + } +} + +pub struct FibonacciAIR +where + F: IsFFTField, +{ + context: AirContext, + trace_length: usize, + pub_inputs: FibonacciPublicInputs, + constraints: Vec>>, +} + +#[derive(Clone, Debug)] +pub struct FibonacciPublicInputs +where + F: IsFFTField, +{ + pub a0: FieldElement, + pub a1: FieldElement, +} + +impl AIR for FibonacciAIR +where + F: IsFFTField + Send + Sync + 'static, +{ + type Field = F; + type FieldExtension = F; + type PublicInputs = FibonacciPublicInputs; + + const STEP_SIZE: usize = 1; + + fn new( + trace_length: usize, + pub_inputs: &Self::PublicInputs, + proof_options: &ProofOptions, + ) -> Self { + let constraints: Vec>> = + vec![Box::new(FibConstraint::new())]; + + let context = AirContext { + proof_options: proof_options.clone(), + trace_columns: 1, + transition_exemptions: vec![2], + transition_offsets: vec![0, 1, 2], + num_transition_constraints: constraints.len(), + }; + + Self { + pub_inputs: pub_inputs.clone(), + context, + trace_length, + constraints, + } + } + + fn composition_poly_degree_bound(&self) -> usize { + self.trace_length() + } + + fn transition_constraints(&self) -> &Vec>> { + &self.constraints + } + + fn boundary_constraints( + &self, + _rap_challenges: &[FieldElement], + ) -> BoundaryConstraints { + let a0 = BoundaryConstraint::new_simple_main(0, self.pub_inputs.a0.clone()); + let a1 = BoundaryConstraint::new_simple_main(1, self.pub_inputs.a1.clone()); + + BoundaryConstraints::from_constraints(vec![a0, a1]) + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn trace_length(&self) -> usize { + self.trace_length + } + + fn trace_layout(&self) -> (usize, usize) { + (1, 0) + } + + fn pub_inputs(&self) -> &Self::PublicInputs { + &self.pub_inputs + } + + fn compute_transition_verifier( + &self, + frame: &Frame, + periodic_values: &[FieldElement], + rap_challenges: &[FieldElement], + ) -> Vec> { + self.compute_transition_prover(frame, periodic_values, rap_challenges) + } +} + +pub fn fibonacci_trace( + initial_values: [FieldElement; 2], + trace_length: usize, +) -> TraceTable { + let mut ret: Vec> = vec![]; + + ret.push(initial_values[0].clone()); + ret.push(initial_values[1].clone()); + + for i in 2..(trace_length) { + ret.push(ret[i - 1].clone() + ret[i - 2].clone()); + } + + TraceTable::from_columns(vec![ret], 1, 1) +} diff --git a/provers/circle_stark/src/air.rs b/provers/circle_stark/src/air.rs new file mode 100644 index 000000000..12757a902 --- /dev/null +++ b/provers/circle_stark/src/air.rs @@ -0,0 +1,102 @@ +use super::{ + constraints::boundary::BoundaryConstraints, frame::Frame, + trace::TraceTable, +}; +use crate::{air_context::AirContext, constraints::transition::TransitionConstraint, domain::Domain}; +use lambdaworks_crypto::fiat_shamir::is_transcript::IsTranscript; +use lambdaworks_math::{ + circle::point::CirclePoint, field::{ + element::FieldElement, + traits::{IsFFTField, IsField, IsSubFieldOf}, + }, polynomial::Polynomial +}; +use std::collections::HashMap; +type ZerofierGroupKey = (usize, usize, Option, Option, usize); +/// AIR is a representation of the Constraints +pub trait AIR { + type Field: IsFFTField + IsSubFieldOf + Send + Sync; + type FieldExtension: IsField + Send + Sync; + type PublicInputs; + + fn new(trace_length: usize, pub_inputs: &Self::PublicInputs) -> Self; + + /// Returns the amount trace columns. + fn trace_layout(&self) -> usize; + + fn composition_poly_degree_bound(&self) -> usize; + + /// The method called by the prover to evaluate the transitions corresponding to an evaluation frame. + /// In the case of the prover, the main evaluation table of the frame takes values in + /// `Self::Field`, since they are the evaluations of the main trace at the LDE domain. + fn compute_transition_prover( + &self, + frame: &Frame, + ) -> Vec> { + let mut evaluations = + vec![FieldElement::::zero(); self.num_transition_constraints()]; + self.transition_constraints() + .iter() + .for_each(|c| c.evaluate(frame, &mut evaluations)); + evaluations + } + + fn boundary_constraints(&self) -> BoundaryConstraints; + + /// The method called by the verifier to evaluate the transitions at the out of domain frame. + /// In the case of the verifier, both main and auxiliary tables of the evaluation frame take + /// values in `Self::FieldExtension`, since they are the evaluations of the trace polynomials + /// at the out of domain challenge. + /// In case `Self::Field` coincides with `Self::FieldExtension`, this method and + /// `compute_transition_prover` should return the same values. + fn compute_transition_verifier( + &self, + frame: &Frame, + ) -> Vec>; + + fn context(&self) -> &AirContext; + + fn trace_length(&self) -> usize; + + fn blowup_factor(&self) -> u8 { + 2 + } + + fn trace_group_generator(&self) -> CirclePoint { + let trace_length = self.trace_length(); + let log_2_length = trace_length.trailing_zeros(); + CirclePoint::get_generator_of_subgroup(log_2_length) + } + + fn num_transition_constraints(&self) -> usize { + self.context().num_transition_constraints + } + + fn pub_inputs(&self) -> &Self::PublicInputs; + + fn transition_constraints( + &self, + ) -> &Vec>>; + + fn transition_zerofier_evaluations( + &self, + domain: &Domain, + ) -> Vec>> { + let mut evals = vec![Vec::new(); self.num_transition_constraints()]; + let mut zerofier_groups: HashMap>> = + HashMap::new(); + + self.transition_constraints().iter().for_each(|c| { + let end_exemptions = c.end_exemptions(); + // This hashmap is used to avoid recomputing with an fft the same zerofier evaluation + // If there are multiple domain and subdomains it can be further optimized + // as to share computation between them + let zerofier_group_key = (end_exemptions); + zerofier_groups + .entry(zerofier_group_key) + .or_insert_with(|| c.zerofier_evaluations_on_extended_domain(domain)); + let zerofier_evaluations = zerofier_groups.get(&zerofier_group_key).unwrap(); + evals[c.constraint_idx()] = zerofier_evaluations.clone(); + }); + evals + } +} diff --git a/provers/circle_stark/src/air_context.rs b/provers/circle_stark/src/air_context.rs new file mode 100644 index 000000000..efee83d08 --- /dev/null +++ b/provers/circle_stark/src/air_context.rs @@ -0,0 +1,31 @@ +use std::collections::HashSet; + +#[derive(Clone, Debug)] +pub struct AirContext { + pub trace_columns: usize, + + /// This is a vector with the indices of all the rows that constitute + /// an evaluation frame. Note that, because of how we write all constraints + /// in one method (`compute_transitions`), this vector needs to include the + /// offsets that are needed to compute EVERY transition constraint, even if some + /// constraints don't use all of the indexes in said offsets. + pub transition_offsets: Vec, + pub transition_exemptions: Vec, + pub num_transition_constraints: usize, +} + +impl AirContext { + pub fn num_transition_constraints(&self) -> usize { + self.num_transition_constraints + } + + /// Returns the number of non-trivial different + /// transition exemptions. + pub fn num_transition_exemptions(&self) -> usize { + self.transition_exemptions + .iter() + .filter(|&x| *x != 0) + .collect::>() + .len() + } +} diff --git a/provers/circle_stark/src/constraints/boundary.rs b/provers/circle_stark/src/constraints/boundary.rs new file mode 100644 index 000000000..50a4bbb98 --- /dev/null +++ b/provers/circle_stark/src/constraints/boundary.rs @@ -0,0 +1,200 @@ +use itertools::Itertools; +use lambdaworks_math::{ + field::{element::FieldElement, traits::IsField}, + polynomial::Polynomial, +}; + +#[derive(Debug)] +/// Represents a boundary constraint that must hold in an execution +/// trace: +/// * col: The column of the trace where the constraint must hold +/// * step: The step (or row) of the trace where the constraint must hold +/// * value: The value the constraint must have in that column and step +pub struct BoundaryConstraint { + pub col: usize, + pub step: usize, + pub value: FieldElement, + pub is_aux: bool, +} + +impl BoundaryConstraint { + pub fn new_main(col: usize, step: usize, value: FieldElement) -> Self { + Self { + col, + step, + value, + is_aux: false, + } + } + + pub fn new_aux(col: usize, step: usize, value: FieldElement) -> Self { + Self { + col, + step, + value, + is_aux: true, + } + } + + /// Used for creating boundary constraints for a trace with only one column + pub fn new_simple_main(step: usize, value: FieldElement) -> Self { + Self { + col: 0, + step, + value, + is_aux: false, + } + } + + /// Used for creating boundary constraints for a trace with only one column + pub fn new_simple_aux(step: usize, value: FieldElement) -> Self { + Self { + col: 0, + step, + value, + is_aux: true, + } + } +} + +/// Data structure that stores all the boundary constraints that must +/// hold for the execution trace +#[derive(Default, Debug)] +pub struct BoundaryConstraints { + pub constraints: Vec>, +} + +impl BoundaryConstraints { + #[allow(dead_code)] + pub fn new() -> Self { + Self { + constraints: Vec::>::new(), + } + } + + /// To instantiate from a vector of BoundaryConstraint elements + pub fn from_constraints(constraints: Vec>) -> Self { + Self { constraints } + } + + /// Returns all the steps where boundary conditions exist for the given column + pub fn steps(&self, col: usize) -> Vec { + self.constraints + .iter() + .filter(|v| v.col == col) + .map(|c| c.step) + .collect() + } + + pub fn steps_for_boundary(&self) -> Vec { + self.constraints + .iter() + .unique_by(|elem| elem.step) + .map(|v| v.step) + .collect() + } + + pub fn cols_for_boundary(&self) -> Vec { + self.constraints + .iter() + .unique_by(|elem| elem.col) + .map(|v| v.col) + .collect() + } + + /// Given the primitive root of some domain, returns the domain values corresponding + /// to the steps where the boundary conditions hold. This is useful when interpolating + /// the boundary conditions, since we must know the x values + pub fn generate_roots_of_unity( + &self, + primitive_root: &FieldElement, + cols_trace: &[usize], + ) -> Vec>> { + cols_trace + .iter() + .map(|i| { + self.steps(*i) + .into_iter() + .map(|s| primitive_root.pow(s)) + .collect::>>() + }) + .collect::>>>() + } + + /// For every trace column, give all the values the trace must be equal to in + /// the steps where the boundary constraints hold + pub fn values(&self, cols_trace: &[usize]) -> Vec>> { + cols_trace + .iter() + .map(|i| { + self.constraints + .iter() + .filter(|c| c.col == *i) + .map(|c| c.value.clone()) + .collect() + }) + .collect() + } + + /// Computes the zerofier of the boundary quotient. The result is the + /// multiplication of each binomial that evaluates to zero in the domain + /// values where the boundary constraints must hold. + /// + /// Example: If there are boundary conditions in the third and fifth steps, + /// then the zerofier will be (x - w^3) * (x - w^5) + pub fn compute_zerofier( + &self, + primitive_root: &FieldElement, + col: usize, + ) -> Polynomial> { + self.steps(col).into_iter().fold( + Polynomial::new_monomial(FieldElement::::one(), 0), + |zerofier, step| { + let binomial = + Polynomial::new(&[-primitive_root.pow(step), FieldElement::::one()]); + // TODO: Implement the MulAssign trait for Polynomials? + zerofier * binomial + }, + ) + } +} + +#[cfg(test)] +mod test { + use lambdaworks_math::field::{ + fields::fft_friendly::stark_252_prime_field::Stark252PrimeField, traits::IsFFTField, + }; + type PrimeField = Stark252PrimeField; + + use super::*; + + #[test] + fn zerofier_is_the_correct_one() { + let one = FieldElement::::one(); + + // Fibonacci constraints: + // * a0 = 1 + // * a1 = 1 + // * a7 = 32 + let a0 = BoundaryConstraint::new_simple_main(0, one); + let a1 = BoundaryConstraint::new_simple_main(1, one); + let result = BoundaryConstraint::new_simple_main(7, FieldElement::::from(32)); + + let constraints = BoundaryConstraints::from_constraints(vec![a0, a1, result]); + + let primitive_root = PrimeField::get_primitive_root_of_unity(3).unwrap(); + + // P_0(x) = (x - 1) + let a0_zerofier = Polynomial::new(&[-one, one]); + // P_1(x) = (x - w^1) + let a1_zerofier = Polynomial::new(&[-primitive_root.pow(1u32), one]); + // P_res(x) = (x - w^7) + let res_zerofier = Polynomial::new(&[-primitive_root.pow(7u32), one]); + + let expected_zerofier = a0_zerofier * a1_zerofier * res_zerofier; + + let zerofier = constraints.compute_zerofier(&primitive_root, 0); + + assert_eq!(expected_zerofier, zerofier); + } +} diff --git a/provers/circle_stark/src/constraints/evaluator.rs b/provers/circle_stark/src/constraints/evaluator.rs new file mode 100644 index 000000000..140a104ab --- /dev/null +++ b/provers/circle_stark/src/constraints/evaluator.rs @@ -0,0 +1,221 @@ +use super::boundary::BoundaryConstraints; +#[cfg(all(debug_assertions, not(feature = "parallel")))] +use crate::debug::check_boundary_polys_divisibility; +use crate::domain::Domain; +use crate::trace::LDETraceTable; +use crate::traits::AIR; +use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain}; +use itertools::Itertools; +#[cfg(all(debug_assertions, not(feature = "parallel")))] +use lambdaworks_math::polynomial::Polynomial; +use lambdaworks_math::{fft::errors::FFTError, field::element::FieldElement, traits::AsBytes}; +#[cfg(feature = "parallel")] +use rayon::{ + iter::IndexedParallelIterator, + prelude::{IntoParallelIterator, ParallelIterator}, +}; +#[cfg(feature = "instruments")] +use std::time::Instant; + +pub struct ConstraintEvaluator { + boundary_constraints: BoundaryConstraints, +} +impl ConstraintEvaluator { + pub fn new(air: &A, rap_challenges: &[FieldElement]) -> Self { + let boundary_constraints = air.boundary_constraints(rap_challenges); + + Self { + boundary_constraints, + } + } + + pub(crate) fn evaluate( + &self, + air: &A, + lde_trace: &LDETraceTable, + domain: &Domain, + transition_coefficients: &[FieldElement], + boundary_coefficients: &[FieldElement], + rap_challenges: &[FieldElement], + ) -> Vec> + where + FieldElement: AsBytes + Send + Sync, + FieldElement: AsBytes + Send + Sync, + A: Send + Sync, + { + let boundary_constraints = &self.boundary_constraints; + let number_of_b_constraints = boundary_constraints.constraints.len(); + let boundary_zerofiers_inverse_evaluations: Vec>> = + boundary_constraints + .constraints + .iter() + .map(|bc| { + let point = &domain.trace_primitive_root.pow(bc.step as u64); + let mut evals = domain + .lde_roots_of_unity_coset + .iter() + .map(|v| v.clone() - point) + .collect::>>(); + FieldElement::inplace_batch_inverse(&mut evals).unwrap(); + evals + }) + .collect::>>>(); + + #[cfg(all(debug_assertions, not(feature = "parallel")))] + let boundary_polys: Vec>> = Vec::new(); + + #[cfg(feature = "instruments")] + let timer = Instant::now(); + + let lde_periodic_columns = air + .get_periodic_column_polynomials() + .iter() + .map(|poly| { + evaluate_polynomial_on_lde_domain( + poly, + domain.blowup_factor, + domain.interpolation_domain_size, + &domain.coset_offset, + ) + }) + .collect::>>, FFTError>>() + .unwrap(); + + #[cfg(feature = "instruments")] + println!( + " Evaluating periodic columns on lde: {:#?}", + timer.elapsed() + ); + + #[cfg(feature = "instruments")] + let timer = Instant::now(); + + let boundary_polys_evaluations = boundary_constraints + .constraints + .iter() + .map(|constraint| { + if constraint.is_aux { + (0..lde_trace.num_rows()) + .map(|row| { + let v = lde_trace.get_aux(row, constraint.col); + v - &constraint.value + }) + .collect_vec() + } else { + (0..lde_trace.num_rows()) + .map(|row| { + let v = lde_trace.get_main(row, constraint.col); + v - &constraint.value + }) + .collect_vec() + } + }) + .collect_vec(); + + #[cfg(feature = "instruments")] + println!(" Created boundary polynomials: {:#?}", timer.elapsed()); + #[cfg(feature = "instruments")] + let timer = Instant::now(); + + #[cfg(feature = "parallel")] + let boundary_eval_iter = (0..domain.lde_roots_of_unity_coset.len()).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let boundary_eval_iter = 0..domain.lde_roots_of_unity_coset.len(); + + let boundary_evaluation: Vec<_> = boundary_eval_iter + .map(|domain_index| { + (0..number_of_b_constraints) + .zip(boundary_coefficients) + .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { + acc + &boundary_zerofiers_inverse_evaluations[constraint_index] + [domain_index] + * beta + * &boundary_polys_evaluations[constraint_index][domain_index] + }) + }) + .collect(); + + #[cfg(feature = "instruments")] + println!( + " Evaluated boundary polynomials on LDE: {:#?}", + timer.elapsed() + ); + + #[cfg(all(debug_assertions, not(feature = "parallel")))] + let boundary_zerofiers = Vec::new(); + + #[cfg(all(debug_assertions, not(feature = "parallel")))] + check_boundary_polys_divisibility(boundary_polys, boundary_zerofiers); + + #[cfg(all(debug_assertions, not(feature = "parallel")))] + let mut transition_evaluations = Vec::new(); + + #[cfg(feature = "instruments")] + let timer = Instant::now(); + let zerofiers_evals = air.transition_zerofier_evaluations(domain); + #[cfg(feature = "instruments")] + println!( + " Evaluated transition zerofiers: {:#?}", + timer.elapsed() + ); + + // Iterate over all LDE domain and compute the part of the composition polynomial + // related to the transition constraints and add it to the already computed part of the + // boundary constraints. + + #[cfg(feature = "instruments")] + let timer = Instant::now(); + let evaluations_t_iter = 0..domain.lde_roots_of_unity_coset.len(); + + #[cfg(feature = "parallel")] + let boundary_evaluation = boundary_evaluation.into_par_iter(); + #[cfg(feature = "parallel")] + let evaluations_t_iter = evaluations_t_iter.into_par_iter(); + + let evaluations_t = evaluations_t_iter + .zip(boundary_evaluation) + .map(|(i, boundary)| { + let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); + + let periodic_values: Vec<_> = lde_periodic_columns + .iter() + .map(|col| col[i].clone()) + .collect(); + + // Compute all the transition constraints at this point of the LDE domain. + let evaluations_transition = + air.compute_transition_prover(&frame, &periodic_values, rap_challenges); + + #[cfg(all(debug_assertions, not(feature = "parallel")))] + transition_evaluations.push(evaluations_transition.clone()); + + // Add each term of the transition constraints to the composition polynomial, including the zerofier, + // the challenge and the exemption polynomial if it is necessary. + let acc_transition = itertools::izip!( + evaluations_transition, + &zerofiers_evals, + transition_coefficients + ) + .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { + // Zerofier evaluations are cyclical, so we only calculate one cycle. + // This means that here we have to wrap around + // Ex: Suppose the full zerofier vector is Z = [1,2,3,1,2,3] + // we will instead have calculated Z' = [1,2,3] + // Now if you need Z[4] this is equal to Z'[1] + let wrapped_idx = i % zerof_eval.len(); + acc + &zerof_eval[wrapped_idx] * eval * beta + }); + + acc_transition + boundary + }) + .collect(); + + #[cfg(feature = "instruments")] + println!( + " Evaluated transitions and accumulated results: {:#?}", + timer.elapsed() + ); + + evaluations_t + } +} diff --git a/provers/circle_stark/src/constraints/mod.rs b/provers/circle_stark/src/constraints/mod.rs new file mode 100644 index 000000000..3811523b5 --- /dev/null +++ b/provers/circle_stark/src/constraints/mod.rs @@ -0,0 +1,3 @@ +pub mod boundary; +pub mod evaluator; +pub mod transition; diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs new file mode 100644 index 000000000..d99ee45ba --- /dev/null +++ b/provers/circle_stark/src/constraints/transition.rs @@ -0,0 +1,239 @@ +use crate::domain::Domain; +use crate::frame::Frame; +use crate::prover::evaluate_polynomial_on_lde_domain; +use itertools::Itertools; +use lambdaworks_math::field::element::FieldElement; +use lambdaworks_math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; +use lambdaworks_math::circle::point::CirclePoint; +use num_integer::Integer; +use std::ops::Div; +/// TransitionConstraint represents the behaviour that a transition constraint +/// over the computation that wants to be proven must comply with. +pub trait TransitionConstraint: Send + Sync +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + /// The degree of the constraint interpreting it as a multivariate polynomial. + fn degree(&self) -> usize; + + /// The index of the constraint. + /// Each transition constraint should have one index in the range [0, N), + /// where N is the total number of transition constraints. + fn constraint_idx(&self) -> usize; + + /// The function representing the evaluation of the constraint over elements + /// of the trace table. + /// + /// Elements of the trace table are found in the `frame` input, and depending on the + /// constraint, elements of `periodic_values` and `rap_challenges` may be used in + /// the evaluation. + /// Once computed, the evaluation should be inserted in the `transition_evaluations` + /// vector, in the index corresponding to the constraint as given by `constraint_idx()`. + fn evaluate( + &self, + frame: &Frame, + transition_evaluations: &mut [FieldElement], + periodic_values: &[FieldElement], + ); + + /// The periodicity the constraint is applied over the trace. + /// + /// Default value is 1, meaning that the constraint is applied to every + /// step of the trace. + fn period(&self) -> usize { + 1 + } + + /// The offset with respect to the first trace row, where the constraint + /// is applied. + /// For example, if the constraint has periodicity 2 and offset 1, this means + /// the constraint will be applied over trace rows of index 1, 3, 5, etc. + /// + /// Default value is 0, meaning that the constraint is applied from the first + /// element of the trace on. + fn offset(&self) -> usize { + 0 + } + + /// For a more fine-grained description of where the constraint should apply, + /// an exemptions period can be defined. + /// This specifies the periodicity of the row indexes where the constraint should + /// NOT apply, within the row indexes where the constraint applies, as specified by + /// `period()` and `offset()`. + /// + /// Default value is None. + fn exemptions_period(&self) -> Option { + None + } + + /// The offset value for periodic exemptions. Check documentation of `period()`, + /// `offset()` and `exemptions_period` for a better understanding. + fn periodic_exemptions_offset(&self) -> Option { + None + } + + /// The number of exemptions at the end of the trace. + /// + /// This method's output defines what trace elements should not be considered for + /// the constraint evaluation at the end of the trace. For example, for a fibonacci + /// computation that has to use the result 2 following steps, this method is defined + /// to return the value 2. + fn end_exemptions(&self) -> usize; + + /// Evaluate the `eval_point` in the polynomial that vanishes in all the exemptions points. + fn evaluate_end_exemptions_poly( + &self, + eval_point: CirclePoint, + trace_group_generator: &CirclePoint, + trace_length: usize, + ) -> FieldElement { + let one = FieldElement::::one(); + if self.end_exemptions() == 0 { + return one; + } + let period = self.period(); + // This accumulates evaluations of the point at the zerofier at all the offsets positions. + (1..=self.end_exemptions()) + .map(|exemption| trace_group_generator * (trace_length - exemption * period)) + .fold(one, |acc, vanishing_point| { + acc * ((eval_point + vanishing_point.conjugate()).x - one) + }) + } + + /// Compute evaluations of the constraints zerofier over a LDE domain. + #[allow(unstable_name_collisions)] + fn zerofier_evaluations_on_extended_domain(&self, domain: &Domain) -> Vec> { + let blowup_factor = domain.blowup_factor; + let trace_length = domain.trace_roots_of_unity.len(); + let trace_primitive_root = &domain.trace_primitive_root; + let coset_offset = &domain.coset_offset; + let lde_root_order = u64::from((blowup_factor * trace_length).trailing_zeros()); + let lde_root = F::get_primitive_root_of_unity(lde_root_order).unwrap(); + + let end_exemptions_poly = self.end_exemptions_poly(trace_primitive_root, trace_length); + + // If there is an exemptions period defined for this constraint, the evaluations are calculated directly + // by computing P_exemptions(x) / Zerofier(x) + if let Some(exemptions_period) = self.exemptions_period() { + // FIXME: Rather than making this assertions here, it would be better to handle these + // errors or make these checks when the AIR is initialized. + + debug_assert!(exemptions_period.is_multiple_of(&self.period())); + + debug_assert!(self.periodic_exemptions_offset().is_some()); + + // The elements of the domain have order `trace_length * blowup_factor`, so the zerofier evaluations + // without the end exemptions, repeat their values after `blowup_factor * exemptions_period` iterations, + // so we only need to compute those. + let last_exponent = blowup_factor * exemptions_period; + + let evaluations: Vec<_> = (0..last_exponent) + .map(|exponent| { + let x = lde_root.pow(exponent); + let offset_times_x = coset_offset * &x; + let offset_exponent = trace_length * self.periodic_exemptions_offset().unwrap() + / exemptions_period; + + let numerator = offset_times_x.pow(trace_length / exemptions_period) + - trace_primitive_root.pow(offset_exponent); + let denominator = offset_times_x.pow(trace_length / self.period()) + - trace_primitive_root.pow(self.offset() * trace_length / self.period()); + + numerator.div(denominator) + }) + .collect(); + + // FIXME: Instead of computing this evaluations for each constraint, they can be computed + // once for every constraint with the same end exemptions (combination of end_exemptions() + // and period). + let end_exemption_evaluations = evaluate_polynomial_on_lde_domain( + &end_exemptions_poly, + blowup_factor, + domain.interpolation_domain_size, + coset_offset, + ) + .unwrap(); + + let cycled_evaluations = evaluations + .iter() + .cycle() + .take(end_exemption_evaluations.len()); + + std::iter::zip(cycled_evaluations, end_exemption_evaluations) + .map(|(eval, exemption_eval)| eval * exemption_eval) + .collect() + + // In this else branch, the zerofiers are computed as the numerator, then inverted + // using batch inverse and then multiplied by P_exemptions(x). This way we don't do + // useless divisions. + } else { + let last_exponent = blowup_factor * self.period(); + + let mut evaluations = (0..last_exponent) + .map(|exponent| { + let x = lde_root.pow(exponent); + (coset_offset * &x).pow(trace_length / self.period()) + - trace_primitive_root.pow(self.offset() * trace_length / self.period()) + }) + .collect_vec(); + + FieldElement::inplace_batch_inverse(&mut evaluations).unwrap(); + + // FIXME: Instead of computing this evaluations for each constraint, they can be computed + // once for every constraint with the same end exemptions (combination of end_exemptions() + // and period). + let end_exemption_evaluations = evaluate_polynomial_on_lde_domain( + &end_exemptions_poly, + blowup_factor, + domain.interpolation_domain_size, + coset_offset, + ) + .unwrap(); + + let cycled_evaluations = evaluations + .iter() + .cycle() + .take(end_exemption_evaluations.len()); + + std::iter::zip(cycled_evaluations, end_exemption_evaluations) + .map(|(eval, exemption_eval)| eval * exemption_eval) + .collect() + } + } + + /// Returns the evaluation of the zerofier corresponding to this constraint in some point + /// `z`, which could be in a field extension. + #[allow(unstable_name_collisions)] + fn evaluate_zerofier( + &self, + z: &FieldElement, + trace_primitive_root: &FieldElement, + trace_length: usize, + ) -> FieldElement { + let end_exemptions_poly = self.end_exemptions_poly(trace_primitive_root, trace_length); + + if let Some(exemptions_period) = self.exemptions_period() { + debug_assert!(exemptions_period.is_multiple_of(&self.period())); + + debug_assert!(self.periodic_exemptions_offset().is_some()); + + let periodic_exemptions_offset = self.periodic_exemptions_offset().unwrap(); + let offset_exponent = trace_length * periodic_exemptions_offset / exemptions_period; + + let numerator = -trace_primitive_root.pow(offset_exponent) + + z.pow(trace_length / exemptions_period); + let denominator = -trace_primitive_root + .pow(self.offset() * trace_length / self.period()) + + z.pow(trace_length / self.period()); + + return numerator.div(denominator) * end_exemptions_poly.evaluate(z); + } + + (-trace_primitive_root.pow(self.offset() * trace_length / self.period()) + + z.pow(trace_length / self.period())) + .inv() + .unwrap() + * end_exemptions_poly.evaluate(z) + } +} diff --git a/provers/circle_stark/src/frame.rs b/provers/circle_stark/src/frame.rs new file mode 100644 index 000000000..fe9742296 --- /dev/null +++ b/provers/circle_stark/src/frame.rs @@ -0,0 +1,42 @@ +use crate::trace::LDETraceTable; +use itertools::Itertools; +use lambdaworks_math::field::{element::FieldElement, traits::IsField}; + +/// A frame represents a collection of trace steps. +/// The collected steps are all the necessary steps for +/// all transition costraints over a trace to be evaluated. +#[derive(Clone, Debug, PartialEq)] +pub struct Frame +where + F: IsField, +{ + steps: Vec>>, +} + +impl Frame { + + pub fn new(steps: Vec>) -> Self { + Self { steps } + } + + pub fn get_evaluation_step(&self, step: usize) -> &Vec> { + &self.steps[step] + } + + pub fn read_from_lde( + lde_trace: &LDETraceTable, + row: usize, + offsets: &[usize], + ) -> Self { + let num_rows = lde_trace.num_rows(); + + let lde_steps = offsets + .iter() + .map(|offset| { + let row = lde_trace.get_row(row + offset); + }) + .collect_vec(); + + Frame::new(lde_steps) + } +} diff --git a/provers/circle_stark/src/lib.rs b/provers/circle_stark/src/lib.rs index aee7cb008..0e428e234 100644 --- a/provers/circle_stark/src/lib.rs +++ b/provers/circle_stark/src/lib.rs @@ -1,3 +1,9 @@ +pub mod air; pub mod config; -pub mod constraints; +pub mod vanishing_poly; +pub mod frame; pub mod prover; +pub mod table; +pub mod trace; +pub mod air_context; +pub mod constraints; \ No newline at end of file diff --git a/provers/circle_stark/src/table.rs b/provers/circle_stark/src/table.rs new file mode 100644 index 000000000..e6e45a127 --- /dev/null +++ b/provers/circle_stark/src/table.rs @@ -0,0 +1,112 @@ +use lambdaworks_math::field::{ + element::FieldElement, + traits::IsField, +}; + +/// A two-dimensional Table holding field elements, arranged in a row-major order. +/// This is the basic underlying data structure used for any two-dimensional component in the +/// the STARK protocol implementation, such as the `TraceTable` and the `EvaluationFrame`. +/// Since this struct is a representation of a two-dimensional table, all rows should have the same +/// length. +#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct Table { + pub data: Vec>, + pub width: usize, + pub height: usize, +} + +impl<'t, F: IsField> Table { + /// Crates a new Table instance from a one-dimensional array in row major order + /// and the intended width of the table. + pub fn new(data: Vec>, width: usize) -> Self { + // Check if the intented width is 0, used for creating an empty table. + if width == 0 { + return Self { + data: Vec::new(), + width, + height: 0, + }; + } + + // Check that the one-dimensional data makes sense to be interpreted as a 2D one. + // debug_assert!(crate::debug::validate_2d_structure(&data, width)); + let height = data.len() / width; + + Self { + data, + width, + height, + } + } + + /// Creates a Table instance from a vector of the intended columns. + pub fn from_columns(columns: Vec>>) -> Self { + if columns.is_empty() { + return Self::new(Vec::new(), 0); + } + let height = columns[0].len(); + + // Check that all columns have the same length for integrity + // debug_assert!(columns.iter().all(|c| c.len() == height)); + + let width = columns.len(); + let mut data = Vec::with_capacity(width * height); + + for row_idx in 0..height { + for column in columns.iter() { + data.push(column[row_idx].clone()); + } + } + + Self::new(data, width) + } + + /// Returns a vector of vectors of field elements representing the table rows + pub fn rows(&self) -> Vec>> { + self.data.chunks(self.width).map(|r| r.to_vec()).collect() + } + + /// Given a row index, returns a reference to that row as a slice of field elements. + pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + let row_offset = row_idx * self.width; + &self.data[row_offset..row_offset + self.width] + } + + /// Given a row index, returns a mutable reference to that row as a slice of field elements. + pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { + let n_cols = self.width; + let row_offset = row_idx * n_cols; + &mut self.data[row_offset..row_offset + n_cols] + } + + /// Given a slice of field elements representing a row, appends it to + /// the end of the table. + pub fn append_row(&mut self, row: &[FieldElement]) { + debug_assert_eq!(row.len(), self.width); + self.data.extend_from_slice(row); + self.height += 1 + } + + /// Returns a reference to the last row of the table + pub fn last_row(&self) -> &[FieldElement] { + self.get_row(self.height - 1) + } + + /// Returns a vector of vectors of field elements representing the table + /// columns + pub fn columns(&self) -> Vec>> { + (0..self.width) + .map(|col_idx| { + (0..self.height) + .map(|row_idx| self.data[row_idx * self.width + col_idx].clone()) + .collect() + }) + .collect() + } + + /// Given row and column indexes, returns the stored field element in that position of the table. + pub fn get(&self, row: usize, col: usize) -> &FieldElement { + let idx = row * self.width + col; + &self.data[idx] + } +} diff --git a/provers/circle_stark/src/trace.rs b/provers/circle_stark/src/trace.rs new file mode 100644 index 000000000..636f28d05 --- /dev/null +++ b/provers/circle_stark/src/trace.rs @@ -0,0 +1,269 @@ +use crate::table::Table; +use itertools::Itertools; +use lambdaworks_math::circle::point::CirclePoint; +use lambdaworks_math::fft::errors::FFTError; +use lambdaworks_math::field::traits::{IsField, IsSubFieldOf}; +use lambdaworks_math::{ + field::{element::FieldElement, traits::IsFFTField}, + circle::polynomial::{interpolate_cfft, evaluate_point} +}; +#[cfg(feature = "parallel")] +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; + +/// A two-dimensional representation of an execution trace of the STARK +/// protocol. +/// +/// For the moment it is mostly a wrapper around the `Table` struct. It is a +/// layer above the raw two-dimensional table, with functionality relevant to the +/// STARK protocol, such as the step size (number of consecutive rows of the table) +/// of the computation being proven. +#[derive(Clone, Default, Debug, PartialEq, Eq)] +pub struct TraceTable { + pub table: Table, + pub num_columns: usize, +} + +impl TraceTable { + pub fn new( + data: Vec>, + num_columns: usize, + ) -> Self { + let table = Table::new(data, num_columns); + Self { + table, + num_columns, + } + } + + pub fn from_columns( + columns: Vec>>, + ) -> Self { + println!("COLUMNS LEN: {}", columns.len()); + let num_columns = columns.len(); + let table = Table::from_columns(columns); + Self { + table, + num_columns, + } + } + + pub fn empty() -> Self { + Self::new(Vec::new(), 0) + } + + pub fn is_empty(&self) -> bool { + self.table.width == 0 + } + + pub fn n_rows(&self) -> usize { + self.table.height + } + + pub fn n_cols(&self) -> usize { + self.table.width + } + + pub fn rows(&self) -> Vec>> { + self.table.rows() + } + + pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + self.table.get_row(row_idx) + } + + pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { + self.table.get_row_mut(row_idx) + } + + pub fn last_row(&self) -> &[FieldElement] { + self.get_row(self.n_rows() - 1) + } + + pub fn columns(&self) -> Vec>> { + self.table.columns() + } + + /// Given a slice of integer numbers representing column indexes, merge these columns into + /// a one-dimensional vector. + /// + /// The particular way they are merged is not really important since this function is used to + /// aggreagate values distributed across various columns with no importance on their ordering, + /// such as to sort them. + pub fn merge_columns(&self, column_indexes: &[usize]) -> Vec> { + let mut data = Vec::with_capacity(self.n_rows() * column_indexes.len()); + for row_index in 0..self.n_rows() { + for column in column_indexes { + data.push(self.table.data[row_index * self.n_cols() + column].clone()); + } + } + data + } + + pub fn compute_trace_polys(&self) -> Vec>> + where + S: IsFFTField + IsSubFieldOf, + FieldElement: Send + Sync, + { + let columns = self.columns(); + #[cfg(feature = "parallel")] + let iter = columns.par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = columns.iter(); + + iter.map(|col| interpolate_cfft(col)) + .collect::>>, FFTError>>() + .unwrap() + } + + /// Given the padding length, appends the last row of the trace table + /// that many times. + /// This is useful for example when the desired trace length should be power + /// of two, and only the last row is the one that can be appended without affecting + /// the integrity of the constraints. + pub fn pad_with_last_row(&mut self, padding_len: usize) { + let last_row = self.last_row().to_vec(); + (0..padding_len).for_each(|_| { + self.table.append_row(&last_row); + }) + } + + /// Given a row index, a column index and a value, tries to set that location + /// of the trace with the given value. + /// The row_idx passed as argument may be greater than the max row index by 1. In this case, + /// last row of the trace is cloned, and the value is set in that cloned row. Then, the row is + /// appended to the end of the trace. + pub fn set_or_extend(&mut self, row_idx: usize, col_idx: usize, value: &FieldElement) { + debug_assert!(col_idx < self.n_cols()); + // NOTE: This is not very nice, but for how the function is being used at the moment, + // the passed `row_idx` should never be greater than `self.n_rows() + 1`. This is just + // an integrity check for ease in the developing process, we should think a better alternative + // in the future. + debug_assert!(row_idx <= self.n_rows() + 1); + if row_idx >= self.n_rows() { + let mut last_row = self.last_row().to_vec(); + last_row[col_idx] = value.clone(); + self.table.append_row(&last_row) + } else { + let row = self.get_row_mut(row_idx); + row[col_idx] = value.clone(); + } + } +} +pub struct LDETraceTable +where + F: IsField, +{ + pub(crate) table: Table, + pub(crate) blowup_factor: usize, +} + +impl LDETraceTable +where + F: IsField, +{ + pub fn new( + data: Vec>, + n_columns: usize, + blowup_factor: usize, + ) -> Self { + let table = Table::new(data, n_columns); + + Self { + table, + blowup_factor, + } + } + + pub fn from_columns( + columns: Vec>>, + blowup_factor: usize, + ) -> Self { + let table = Table::from_columns(columns); + + Self { + table, + blowup_factor, + } + } + + pub fn num_cols(&self) -> usize { + self.table.width + } + + pub fn num_rows(&self) -> usize { + self.table.height + } + + pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + self.table.get_row(row_idx) + } + + pub fn get_table(&self, row: usize, col: usize) -> &FieldElement { + self.table.get(row, col) + } +} + +/// Given a slice of trace polynomials, an evaluation point `x`, the frame offsets +/// corresponding to the computation of the transitions, and a primitive root, +/// outputs the trace evaluations of each trace polynomial over the values used to +/// compute a transition. +/// Example: For a simple Fibonacci computation, if t(x) is the trace polynomial of +/// the computation, this will output evaluations t(x), t(g * x), t(g^2 * z). +pub fn get_trace_evaluations( + trace_polys: &[Vec>], + point: &CirclePoint, + frame_offsets: &[usize], + group_generator: &CirclePoint, +) -> Table +where + F: IsField, +{ + let evaluation_points = frame_offsets + .iter() + .map(|offset| (group_generator * offset) + point) + .collect_vec(); + + let evaluations = evaluation_points + .iter() + .map(|eval_point| { + trace_polys + .iter() + .map(|poly| evaluate_point(poly, eval_point)) + .collect_vec() + }) + .collect_vec(); + + let table_width = trace_polys.len(); + Table::new(evaluations, table_width) +} + +pub fn columns2rows(columns: Vec>>) -> Vec>> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + + (0..num_rows) + .map(|row_index| { + (0..num_cols) + .map(|col_index| columns[col_index][row_index].clone()) + .collect() + }) + .collect() +} + +#[cfg(test)] +mod test { + use super::TraceTable; + use lambdaworks_math::field::{element::FieldElement, fields::u64_prime_field::F17}; + type FE = FieldElement; + + #[test] + fn test_cols() { + let col_1 = vec![FE::from(1), FE::from(2), FE::from(5), FE::from(13)]; + let col_2 = vec![FE::from(1), FE::from(3), FE::from(8), FE::from(21)]; + + let trace_table = TraceTable::from_columns(vec![col_1.clone(), col_2.clone()]); + let res_cols = trace_table.columns(); + + assert_eq!(res_cols, vec![col_1, col_2]); + } +} diff --git a/provers/circle_stark/src/constraints.rs b/provers/circle_stark/src/vanishing_poly.rs similarity index 100% rename from provers/circle_stark/src/constraints.rs rename to provers/circle_stark/src/vanishing_poly.rs From ed136c41df249a0ed87440ce7f9cf4886669743b Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 4 Nov 2024 17:23:59 -0300 Subject: [PATCH 80/93] save changes --- provers/circle_stark/src/air.rs | 13 +- .../src/constraints/transition.rs | 177 +++++------------- provers/circle_stark/src/domain.rs | 43 +++++ provers/circle_stark/src/frame.rs | 9 +- provers/circle_stark/src/lib.rs | 7 +- 5 files changed, 106 insertions(+), 143 deletions(-) create mode 100644 provers/circle_stark/src/domain.rs diff --git a/provers/circle_stark/src/air.rs b/provers/circle_stark/src/air.rs index 12757a902..a2ac9bf8b 100644 --- a/provers/circle_stark/src/air.rs +++ b/provers/circle_stark/src/air.rs @@ -1,14 +1,13 @@ -use super::{ - constraints::boundary::BoundaryConstraints, frame::Frame, - trace::TraceTable, +use super::{constraints::boundary::BoundaryConstraints, frame::Frame}; +use crate::{ + air_context::AirContext, constraints::transition::TransitionConstraint, domain::Domain, }; -use crate::{air_context::AirContext, constraints::transition::TransitionConstraint, domain::Domain}; -use lambdaworks_crypto::fiat_shamir::is_transcript::IsTranscript; use lambdaworks_math::{ - circle::point::CirclePoint, field::{ + circle::point::CirclePoint, + field::{ element::FieldElement, traits::{IsFFTField, IsField, IsSubFieldOf}, - }, polynomial::Polynomial + }, }; use std::collections::HashMap; type ZerofierGroupKey = (usize, usize, Option, Option, usize); diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index d99ee45ba..3360415d3 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -1,11 +1,9 @@ use crate::domain::Domain; use crate::frame::Frame; -use crate::prover::evaluate_polynomial_on_lde_domain; use itertools::Itertools; +use lambdaworks_math::circle::{cosets::Coset, point::CirclePoint}; use lambdaworks_math::field::element::FieldElement; use lambdaworks_math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; -use lambdaworks_math::circle::point::CirclePoint; -use num_integer::Integer; use std::ops::Div; /// TransitionConstraint represents the behaviour that a transition constraint /// over the computation that wants to be proven must comply with. @@ -30,12 +28,7 @@ where /// the evaluation. /// Once computed, the evaluation should be inserted in the `transition_evaluations` /// vector, in the index corresponding to the constraint as given by `constraint_idx()`. - fn evaluate( - &self, - frame: &Frame, - transition_evaluations: &mut [FieldElement], - periodic_values: &[FieldElement], - ); + fn evaluate(&self, frame: &Frame, transition_evaluations: &mut [FieldElement]); /// The periodicity the constraint is applied over the trace. /// @@ -84,7 +77,8 @@ where /// Evaluate the `eval_point` in the polynomial that vanishes in all the exemptions points. fn evaluate_end_exemptions_poly( &self, - eval_point: CirclePoint, + eval_point: CirclePoint, + // `trace_group_generator` can be calculated with `trace_length` but it is better to precompute it trace_group_generator: &CirclePoint, trace_length: usize, ) -> FieldElement { @@ -102,138 +96,69 @@ where } /// Compute evaluations of the constraints zerofier over a LDE domain. + /// TODO: See if we can evaluate using cfft. + /// TODO: See if we can optimize computing only some evaluations and cycle them as in regular stark. #[allow(unstable_name_collisions)] fn zerofier_evaluations_on_extended_domain(&self, domain: &Domain) -> Vec> { let blowup_factor = domain.blowup_factor; let trace_length = domain.trace_roots_of_unity.len(); - let trace_primitive_root = &domain.trace_primitive_root; - let coset_offset = &domain.coset_offset; - let lde_root_order = u64::from((blowup_factor * trace_length).trailing_zeros()); - let lde_root = F::get_primitive_root_of_unity(lde_root_order).unwrap(); - - let end_exemptions_poly = self.end_exemptions_poly(trace_primitive_root, trace_length); - - // If there is an exemptions period defined for this constraint, the evaluations are calculated directly - // by computing P_exemptions(x) / Zerofier(x) - if let Some(exemptions_period) = self.exemptions_period() { - // FIXME: Rather than making this assertions here, it would be better to handle these - // errors or make these checks when the AIR is initialized. - - debug_assert!(exemptions_period.is_multiple_of(&self.period())); - - debug_assert!(self.periodic_exemptions_offset().is_some()); - - // The elements of the domain have order `trace_length * blowup_factor`, so the zerofier evaluations - // without the end exemptions, repeat their values after `blowup_factor * exemptions_period` iterations, - // so we only need to compute those. - let last_exponent = blowup_factor * exemptions_period; - - let evaluations: Vec<_> = (0..last_exponent) - .map(|exponent| { - let x = lde_root.pow(exponent); - let offset_times_x = coset_offset * &x; - let offset_exponent = trace_length * self.periodic_exemptions_offset().unwrap() - / exemptions_period; - - let numerator = offset_times_x.pow(trace_length / exemptions_period) - - trace_primitive_root.pow(offset_exponent); - let denominator = offset_times_x.pow(trace_length / self.period()) - - trace_primitive_root.pow(self.offset() * trace_length / self.period()); - - numerator.div(denominator) - }) - .collect(); - - // FIXME: Instead of computing this evaluations for each constraint, they can be computed - // once for every constraint with the same end exemptions (combination of end_exemptions() - // and period). - let end_exemption_evaluations = evaluate_polynomial_on_lde_domain( - &end_exemptions_poly, - blowup_factor, - domain.interpolation_domain_size, - coset_offset, - ) - .unwrap(); - - let cycled_evaluations = evaluations - .iter() - .cycle() - .take(end_exemption_evaluations.len()); - - std::iter::zip(cycled_evaluations, end_exemption_evaluations) - .map(|(eval, exemption_eval)| eval * exemption_eval) - .collect() - - // In this else branch, the zerofiers are computed as the numerator, then inverted - // using batch inverse and then multiplied by P_exemptions(x). This way we don't do - // useless divisions. - } else { - let last_exponent = blowup_factor * self.period(); - - let mut evaluations = (0..last_exponent) - .map(|exponent| { - let x = lde_root.pow(exponent); - (coset_offset * &x).pow(trace_length / self.period()) - - trace_primitive_root.pow(self.offset() * trace_length / self.period()) - }) - .collect_vec(); - - FieldElement::inplace_batch_inverse(&mut evaluations).unwrap(); - - // FIXME: Instead of computing this evaluations for each constraint, they can be computed - // once for every constraint with the same end exemptions (combination of end_exemptions() - // and period). - let end_exemption_evaluations = evaluate_polynomial_on_lde_domain( - &end_exemptions_poly, - blowup_factor, - domain.interpolation_domain_size, - coset_offset, - ) - .unwrap(); - - let cycled_evaluations = evaluations - .iter() - .cycle() - .take(end_exemption_evaluations.len()); - - std::iter::zip(cycled_evaluations, end_exemption_evaluations) - .map(|(eval, exemption_eval)| eval * exemption_eval) - .collect() - } + let trace_log_2_size = trace_length.trailing_zeros(); + let lde_log_2_size = (blowup_factor * trace_length).trailing_zeros(); + let trace_group_generator = &domain.trace_primitive_root; + + // if let Some(exemptions_period) = self.exemptions_period() { + + // } else { + + let lde_coset = Coset::new_standard(lde_log_2_size); + let lde_points = Coset::get_points(lde_coset); + + let zerofier_evaluations = lde_points + .iter() + .map(|point| { + let mut x = point.x; + for _ in 1..trace_log_2_size { + x = x.square().double() - FieldElement::one(); + } + x + }) + .collect(); + FieldElement::inplace_batch_inverse(&mut zerofier_evaluations).unwrap(); + + let end_exemptions_evaluations = lde_points + .iter() + .map(|point| { + self.evaluate_end_exemptions_poly(point, trace_group_generator, trace_length) + }) + .collect(); + + std::iter::zip(zerofier_evaluations, end_exemptions_evaluations) + .map(|(eval, exemptions_eval)| eval * exemptions_eval) + .collect() } /// Returns the evaluation of the zerofier corresponding to this constraint in some point - /// `z`, which could be in a field extension. + /// `eval_point`, (which is in the circle over the extension field). #[allow(unstable_name_collisions)] fn evaluate_zerofier( &self, - z: &FieldElement, - trace_primitive_root: &FieldElement, + eval_point: &CirclePoint, + trace_group_generator: &FieldElement, trace_length: usize, ) -> FieldElement { - let end_exemptions_poly = self.end_exemptions_poly(trace_primitive_root, trace_length); - - if let Some(exemptions_period) = self.exemptions_period() { - debug_assert!(exemptions_period.is_multiple_of(&self.period())); - - debug_assert!(self.periodic_exemptions_offset().is_some()); + // if let Some(exemptions_period) = self.exemptions_period() { - let periodic_exemptions_offset = self.periodic_exemptions_offset().unwrap(); - let offset_exponent = trace_length * periodic_exemptions_offset / exemptions_period; + // } else { - let numerator = -trace_primitive_root.pow(offset_exponent) - + z.pow(trace_length / exemptions_period); - let denominator = -trace_primitive_root - .pow(self.offset() * trace_length / self.period()) - + z.pow(trace_length / self.period()); + let end_exemptions_evaluation = + self.evaluate_end_exemptions_poly(eval_point, trace_group_generator, trace_length); - return numerator.div(denominator) * end_exemptions_poly.evaluate(z); + let trace_log_2_size = trace_length.trailing_zeros(); + let mut x = eval_point.x; + for _ in 1..trace_log_2_size { + x = x.square().double() - FieldElement::one(); } - (-trace_primitive_root.pow(self.offset() * trace_length / self.period()) - + z.pow(trace_length / self.period())) - .inv() - .unwrap() - * end_exemptions_poly.evaluate(z) + x.inv().unwrap() * end_exemptions_evaluation } } diff --git a/provers/circle_stark/src/domain.rs b/provers/circle_stark/src/domain.rs new file mode 100644 index 000000000..948f03aff --- /dev/null +++ b/provers/circle_stark/src/domain.rs @@ -0,0 +1,43 @@ +use lambdaworks_math::{ + circle::{cosets::Coset, point::CirclePoint}, + field::{element::FieldElement, traits::IsFFTField}, +}; + +use super::air::AIR; + +pub struct Domain { + pub(crate) trace_length: usize, + pub(crate) trace_log_2_length: u32, + pub(crate) blowup_factor: usize, + pub(crate) trace_coset_points: Vec>, + pub(crate) lde_coset_points: Vec>, + pub(crate) trace_group_generator: FieldElement, +} + +impl Domain { + pub fn new(air: &A) -> Self + where + A: AIR, + { + // Initial definitions + let trace_length = air.trace_length(); + let trace_log_2_length = trace_length.trailing_zeros(); + let blowup_factor = air.blowup_factor() as usize; + + // * Generate Coset + let trace_coset_points = Coset::get_coset_points(&Coset::new_standard(trace_log_2_length)); + let lde_coset_points = Coset::get_coset_points(&Coset::new_standard( + (blowup_factor * trace_length).trailing_zeros(), + )); + let trace_group_generator = CirclePoint::get_generator_of_subgroup(trace_log_2_length); + + Self { + trace_length, + trace_log_2_length, + blowup_factor, + trace_coset_points, + lde_coset_points, + trace_group_generator, + } + } +} diff --git a/provers/circle_stark/src/frame.rs b/provers/circle_stark/src/frame.rs index fe9742296..9be7b7183 100644 --- a/provers/circle_stark/src/frame.rs +++ b/provers/circle_stark/src/frame.rs @@ -14,8 +14,7 @@ where } impl Frame { - - pub fn new(steps: Vec>) -> Self { + pub fn new(steps: Vec>>) -> Self { Self { steps } } @@ -23,11 +22,7 @@ impl Frame { &self.steps[step] } - pub fn read_from_lde( - lde_trace: &LDETraceTable, - row: usize, - offsets: &[usize], - ) -> Self { + pub fn read_from_lde(lde_trace: &LDETraceTable, row: usize, offsets: &[usize]) -> Self { let num_rows = lde_trace.num_rows(); let lde_steps = offsets diff --git a/provers/circle_stark/src/lib.rs b/provers/circle_stark/src/lib.rs index 0e428e234..cb19024ce 100644 --- a/provers/circle_stark/src/lib.rs +++ b/provers/circle_stark/src/lib.rs @@ -1,9 +1,10 @@ pub mod air; +pub mod air_context; pub mod config; -pub mod vanishing_poly; +pub mod constraints; +pub mod domain; pub mod frame; pub mod prover; pub mod table; pub mod trace; -pub mod air_context; -pub mod constraints; \ No newline at end of file +pub mod vanishing_poly; From 268bdb92b176ba037e8301030bc004ac0ff2d0ba Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 4 Nov 2024 18:44:26 -0300 Subject: [PATCH 81/93] save changes --- provers/circle_stark/src/air.rs | 16 +++--- .../src/constraints/transition.rs | 42 ++++++++-------- provers/circle_stark/src/domain.rs | 11 +++-- provers/circle_stark/src/frame.rs | 4 +- provers/circle_stark/src/trace.rs | 49 +++++++------------ 5 files changed, 56 insertions(+), 66 deletions(-) diff --git a/provers/circle_stark/src/air.rs b/provers/circle_stark/src/air.rs index a2ac9bf8b..7dbe3f2ca 100644 --- a/provers/circle_stark/src/air.rs +++ b/provers/circle_stark/src/air.rs @@ -3,7 +3,7 @@ use crate::{ air_context::AirContext, constraints::transition::TransitionConstraint, domain::Domain, }; use lambdaworks_math::{ - circle::point::CirclePoint, + circle::point::{CirclePoint, HasCircleParams}, field::{ element::FieldElement, traits::{IsFFTField, IsField, IsSubFieldOf}, @@ -13,8 +13,12 @@ use std::collections::HashMap; type ZerofierGroupKey = (usize, usize, Option, Option, usize); /// AIR is a representation of the Constraints pub trait AIR { - type Field: IsFFTField + IsSubFieldOf + Send + Sync; - type FieldExtension: IsField + Send + Sync; + type Field: IsFFTField + + IsSubFieldOf + + Send + + Sync + + HasCircleParams; + type FieldExtension: IsField + Send + Sync + HasCircleParams; type PublicInputs; fn new(trace_length: usize, pub_inputs: &Self::PublicInputs) -> Self; @@ -30,9 +34,9 @@ pub trait AIR { fn compute_transition_prover( &self, frame: &Frame, - ) -> Vec> { + ) -> Vec> { let mut evaluations = - vec![FieldElement::::zero(); self.num_transition_constraints()]; + vec![FieldElement::::zero(); self.num_transition_constraints()]; self.transition_constraints() .iter() .for_each(|c| c.evaluate(frame, &mut evaluations)); @@ -89,7 +93,7 @@ pub trait AIR { // This hashmap is used to avoid recomputing with an fft the same zerofier evaluation // If there are multiple domain and subdomains it can be further optimized // as to share computation between them - let zerofier_group_key = (end_exemptions); + let zerofier_group_key = end_exemptions; zerofier_groups .entry(zerofier_group_key) .or_insert_with(|| c.zerofier_evaluations_on_extended_domain(domain)); diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index 3360415d3..fe5933c0d 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -1,16 +1,14 @@ use crate::domain::Domain; use crate::frame::Frame; -use itertools::Itertools; -use lambdaworks_math::circle::{cosets::Coset, point::CirclePoint}; +use lambdaworks_math::circle::point::{CirclePoint, HasCircleParams}; use lambdaworks_math::field::element::FieldElement; use lambdaworks_math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; -use std::ops::Div; /// TransitionConstraint represents the behaviour that a transition constraint /// over the computation that wants to be proven must comply with. pub trait TransitionConstraint: Send + Sync where - F: IsSubFieldOf + IsFFTField + Send + Sync, - E: IsField + Send + Sync, + F: IsSubFieldOf + IsFFTField + Send + Sync + HasCircleParams, + E: IsField + Send + Sync + HasCircleParams, { /// The degree of the constraint interpreting it as a multivariate polynomial. fn degree(&self) -> usize; @@ -77,7 +75,7 @@ where /// Evaluate the `eval_point` in the polynomial that vanishes in all the exemptions points. fn evaluate_end_exemptions_poly( &self, - eval_point: CirclePoint, + eval_point: &CirclePoint, // `trace_group_generator` can be calculated with `trace_length` but it is better to precompute it trace_group_generator: &CirclePoint, trace_length: usize, @@ -89,9 +87,9 @@ where let period = self.period(); // This accumulates evaluations of the point at the zerofier at all the offsets positions. (1..=self.end_exemptions()) - .map(|exemption| trace_group_generator * (trace_length - exemption * period)) - .fold(one, |acc, vanishing_point| { - acc * ((eval_point + vanishing_point.conjugate()).x - one) + .map(|exemption| trace_group_generator * ((trace_length - exemption * period) as u128)) + .fold(one.clone(), |acc, vanishing_point| { + acc * ((eval_point + vanishing_point.conjugate()).x - &one) }) } @@ -101,31 +99,31 @@ where #[allow(unstable_name_collisions)] fn zerofier_evaluations_on_extended_domain(&self, domain: &Domain) -> Vec> { let blowup_factor = domain.blowup_factor; - let trace_length = domain.trace_roots_of_unity.len(); + let trace_length = domain.trace_length; let trace_log_2_size = trace_length.trailing_zeros(); let lde_log_2_size = (blowup_factor * trace_length).trailing_zeros(); - let trace_group_generator = &domain.trace_primitive_root; + let trace_group_generator = &domain.trace_group_generator; // if let Some(exemptions_period) = self.exemptions_period() { // } else { - let lde_coset = Coset::new_standard(lde_log_2_size); - let lde_points = Coset::get_points(lde_coset); + let lde_points = &domain.lde_coset_points; - let zerofier_evaluations = lde_points + let mut zerofier_evaluations: Vec<_> = lde_points .iter() .map(|point| { - let mut x = point.x; + // TODO: Is there a way to avoid this clone()? + let mut x = point.x.clone(); for _ in 1..trace_log_2_size { - x = x.square().double() - FieldElement::one(); + x = x.square().double() - FieldElement::::one(); } x }) .collect(); FieldElement::inplace_batch_inverse(&mut zerofier_evaluations).unwrap(); - let end_exemptions_evaluations = lde_points + let end_exemptions_evaluations: Vec<_> = lde_points .iter() .map(|point| { self.evaluate_end_exemptions_poly(point, trace_group_generator, trace_length) @@ -142,10 +140,10 @@ where #[allow(unstable_name_collisions)] fn evaluate_zerofier( &self, - eval_point: &CirclePoint, - trace_group_generator: &FieldElement, + eval_point: &CirclePoint, + trace_group_generator: &CirclePoint, trace_length: usize, - ) -> FieldElement { + ) -> FieldElement { // if let Some(exemptions_period) = self.exemptions_period() { // } else { @@ -154,9 +152,9 @@ where self.evaluate_end_exemptions_poly(eval_point, trace_group_generator, trace_length); let trace_log_2_size = trace_length.trailing_zeros(); - let mut x = eval_point.x; + let mut x = eval_point.x.clone(); for _ in 1..trace_log_2_size { - x = x.square().double() - FieldElement::one(); + x = x.square().double() - FieldElement::::one(); } x.inv().unwrap() * end_exemptions_evaluation diff --git a/provers/circle_stark/src/domain.rs b/provers/circle_stark/src/domain.rs index 948f03aff..818515966 100644 --- a/provers/circle_stark/src/domain.rs +++ b/provers/circle_stark/src/domain.rs @@ -1,20 +1,23 @@ use lambdaworks_math::{ - circle::{cosets::Coset, point::CirclePoint}, + circle::{ + cosets::Coset, + point::{CirclePoint, HasCircleParams}, + }, field::{element::FieldElement, traits::IsFFTField}, }; use super::air::AIR; -pub struct Domain { +pub struct Domain> { pub(crate) trace_length: usize, pub(crate) trace_log_2_length: u32, pub(crate) blowup_factor: usize, pub(crate) trace_coset_points: Vec>, pub(crate) lde_coset_points: Vec>, - pub(crate) trace_group_generator: FieldElement, + pub(crate) trace_group_generator: CirclePoint, } -impl Domain { +impl> Domain { pub fn new(air: &A) -> Self where A: AIR, diff --git a/provers/circle_stark/src/frame.rs b/provers/circle_stark/src/frame.rs index 9be7b7183..11e7c92d4 100644 --- a/provers/circle_stark/src/frame.rs +++ b/provers/circle_stark/src/frame.rs @@ -27,9 +27,7 @@ impl Frame { let lde_steps = offsets .iter() - .map(|offset| { - let row = lde_trace.get_row(row + offset); - }) + .map(|offset| lde_trace.get_row(row + offset).to_vec()) .collect_vec(); Frame::new(lde_steps) diff --git a/provers/circle_stark/src/trace.rs b/provers/circle_stark/src/trace.rs index 636f28d05..a9eabd8f0 100644 --- a/provers/circle_stark/src/trace.rs +++ b/provers/circle_stark/src/trace.rs @@ -1,11 +1,16 @@ use crate::table::Table; use itertools::Itertools; -use lambdaworks_math::circle::point::CirclePoint; -use lambdaworks_math::fft::errors::FFTError; -use lambdaworks_math::field::traits::{IsField, IsSubFieldOf}; use lambdaworks_math::{ - field::{element::FieldElement, traits::IsFFTField}, - circle::polynomial::{interpolate_cfft, evaluate_point} + circle::{ + point::{CirclePoint, HasCircleParams}, + polynomial::{evaluate_point, interpolate_cfft}, + }, + fft::errors::FFTError, + field::{ + element::FieldElement, + traits::IsFFTField, + traits::{IsField, IsSubFieldOf}, + }, }; #[cfg(feature = "parallel")] use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; @@ -24,27 +29,16 @@ pub struct TraceTable { } impl TraceTable { - pub fn new( - data: Vec>, - num_columns: usize, - ) -> Self { + pub fn new(data: Vec>, num_columns: usize) -> Self { let table = Table::new(data, num_columns); - Self { - table, - num_columns, - } + Self { table, num_columns } } - pub fn from_columns( - columns: Vec>>, - ) -> Self { + pub fn from_columns(columns: Vec>>) -> Self { println!("COLUMNS LEN: {}", columns.len()); let num_columns = columns.len(); let table = Table::from_columns(columns); - Self { - table, - num_columns, - } + Self { table, num_columns } } pub fn empty() -> Self { @@ -161,11 +155,7 @@ impl LDETraceTable where F: IsField, { - pub fn new( - data: Vec>, - n_columns: usize, - blowup_factor: usize, - ) -> Self { + pub fn new(data: Vec>, n_columns: usize, blowup_factor: usize) -> Self { let table = Table::new(data, n_columns); Self { @@ -174,10 +164,7 @@ where } } - pub fn from_columns( - columns: Vec>>, - blowup_factor: usize, - ) -> Self { + pub fn from_columns(columns: Vec>>, blowup_factor: usize) -> Self { let table = Table::from_columns(columns); Self { @@ -216,11 +203,11 @@ pub fn get_trace_evaluations( group_generator: &CirclePoint, ) -> Table where - F: IsField, + F: IsField + HasCircleParams, { let evaluation_points = frame_offsets .iter() - .map(|offset| (group_generator * offset) + point) + .map(|offset| (group_generator * (*offset as u128)) + point) .collect_vec(); let evaluations = evaluation_points From 1a84697ec6933a50291f36feacdef090ad109930 Mon Sep 17 00:00:00 2001 From: Nicole Date: Tue, 5 Nov 2024 16:42:57 -0300 Subject: [PATCH 82/93] boundary constraints struct --- .../circle_stark/src/constraints/boundary.rs | 127 ++++++++---------- .../src/constraints/transition.rs | 1 + 2 files changed, 54 insertions(+), 74 deletions(-) diff --git a/provers/circle_stark/src/constraints/boundary.rs b/provers/circle_stark/src/constraints/boundary.rs index 50a4bbb98..c64a84c5f 100644 --- a/provers/circle_stark/src/constraints/boundary.rs +++ b/provers/circle_stark/src/constraints/boundary.rs @@ -1,6 +1,7 @@ use itertools::Itertools; use lambdaworks_math::{ - field::{element::FieldElement, traits::IsField}, + circle::point::CirclePoint, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field, traits::IsField}, polynomial::Polynomial, }; @@ -10,49 +11,23 @@ use lambdaworks_math::{ /// * col: The column of the trace where the constraint must hold /// * step: The step (or row) of the trace where the constraint must hold /// * value: The value the constraint must have in that column and step -pub struct BoundaryConstraint { +pub struct BoundaryConstraint { pub col: usize, pub step: usize, - pub value: FieldElement, - pub is_aux: bool, + pub value: FieldElement, } -impl BoundaryConstraint { - pub fn new_main(col: usize, step: usize, value: FieldElement) -> Self { - Self { - col, - step, - value, - is_aux: false, - } - } - - pub fn new_aux(col: usize, step: usize, value: FieldElement) -> Self { - Self { - col, - step, - value, - is_aux: true, - } - } - - /// Used for creating boundary constraints for a trace with only one column - pub fn new_simple_main(step: usize, value: FieldElement) -> Self { - Self { - col: 0, - step, - value, - is_aux: false, - } +impl BoundaryConstraint { + pub fn new(col: usize, step: usize, value: FieldElement) -> Self { + Self { col, step, value } } /// Used for creating boundary constraints for a trace with only one column - pub fn new_simple_aux(step: usize, value: FieldElement) -> Self { + pub fn new_simple(step: usize, value: FieldElement) -> Self { Self { col: 0, step, value, - is_aux: true, } } } @@ -60,20 +35,20 @@ impl BoundaryConstraint { /// Data structure that stores all the boundary constraints that must /// hold for the execution trace #[derive(Default, Debug)] -pub struct BoundaryConstraints { - pub constraints: Vec>, +pub struct BoundaryConstraints { + pub constraints: Vec, } -impl BoundaryConstraints { +impl BoundaryConstraints { #[allow(dead_code)] pub fn new() -> Self { Self { - constraints: Vec::>::new(), + constraints: Vec::::new(), } } /// To instantiate from a vector of BoundaryConstraint elements - pub fn from_constraints(constraints: Vec>) -> Self { + pub fn from_constraints(constraints: Vec) -> Self { Self { constraints } } @@ -86,6 +61,7 @@ impl BoundaryConstraints { .collect() } + /// Return all the steps where boundary constraints hold. pub fn steps_for_boundary(&self) -> Vec { self.constraints .iter() @@ -94,6 +70,7 @@ impl BoundaryConstraints { .collect() } + /// Return all the columns where boundary constraints hold. pub fn cols_for_boundary(&self) -> Vec { self.constraints .iter() @@ -102,28 +79,28 @@ impl BoundaryConstraints { .collect() } - /// Given the primitive root of some domain, returns the domain values corresponding + /// Given the group generator of some domain, returns for each column the domain values corresponding /// to the steps where the boundary conditions hold. This is useful when interpolating /// the boundary conditions, since we must know the x values pub fn generate_roots_of_unity( &self, - primitive_root: &FieldElement, + group_generator: &CirclePoint, cols_trace: &[usize], - ) -> Vec>> { + ) -> Vec>> { cols_trace .iter() .map(|i| { self.steps(*i) .into_iter() - .map(|s| primitive_root.pow(s)) - .collect::>>() + .map(|s| group_generator * (s as u128)) + .collect::>>() }) - .collect::>>>() + .collect::>>>() } /// For every trace column, give all the values the trace must be equal to in /// the steps where the boundary constraints hold - pub fn values(&self, cols_trace: &[usize]) -> Vec>> { + pub fn values(&self, cols_trace: &[usize]) -> Vec>> { cols_trace .iter() .map(|i| { @@ -136,24 +113,28 @@ impl BoundaryConstraints { .collect() } - /// Computes the zerofier of the boundary quotient. The result is the - /// multiplication of each binomial that evaluates to zero in the domain + /// Evaluate the zerofier of the boundary constraints for a column. The result is the + /// multiplication of each zerofier that evaluates to zero in the domain /// values where the boundary constraints must hold. /// /// Example: If there are boundary conditions in the third and fifth steps, - /// then the zerofier will be (x - w^3) * (x - w^5) - pub fn compute_zerofier( + /// then the zerofier will be f(x, y) = ( ((x, y) + p3.conjugate()).x - 1 ) * ( ((x, y) + p5.conjugate()).x - 1 ) + /// (eval_point + vanish_point.conjugate()).x - FieldElement::::one() + /// TODO: Optimize this function so we don't need to look up and indexes in the coset vector and clone its value. + pub fn evaluate_zerofier( &self, - primitive_root: &FieldElement, + trace_coset: &Vec>, col: usize, - ) -> Polynomial> { + eval_point: &CirclePoint, + ) -> FieldElement { self.steps(col).into_iter().fold( - Polynomial::new_monomial(FieldElement::::one(), 0), + FieldElement::::one(), |zerofier, step| { - let binomial = - Polynomial::new(&[-primitive_root.pow(step), FieldElement::::one()]); + let vanish_point = trace_coset[step].clone(); + let evaluation = (eval_point + vanish_point.conjugate()).x + - FieldElement::::one(); // TODO: Implement the MulAssign trait for Polynomials? - zerofier * binomial + zerofier * evaluation }, ) } @@ -161,8 +142,11 @@ impl BoundaryConstraints { #[cfg(test)] mod test { - use lambdaworks_math::field::{ - fields::fft_friendly::stark_252_prime_field::Stark252PrimeField, traits::IsFFTField, + use lambdaworks_math::{ + circle::cosets::Coset, + field::{ + fields::fft_friendly::stark_252_prime_field::Stark252PrimeField, traits::IsFFTField, + }, }; type PrimeField = Stark252PrimeField; @@ -170,30 +154,25 @@ mod test { #[test] fn zerofier_is_the_correct_one() { - let one = FieldElement::::one(); + let one = FieldElement::::one(); // Fibonacci constraints: // * a0 = 1 // * a1 = 1 // * a7 = 32 - let a0 = BoundaryConstraint::new_simple_main(0, one); - let a1 = BoundaryConstraint::new_simple_main(1, one); - let result = BoundaryConstraint::new_simple_main(7, FieldElement::::from(32)); - - let constraints = BoundaryConstraints::from_constraints(vec![a0, a1, result]); - - let primitive_root = PrimeField::get_primitive_root_of_unity(3).unwrap(); - - // P_0(x) = (x - 1) - let a0_zerofier = Polynomial::new(&[-one, one]); - // P_1(x) = (x - w^1) - let a1_zerofier = Polynomial::new(&[-primitive_root.pow(1u32), one]); - // P_res(x) = (x - w^7) - let res_zerofier = Polynomial::new(&[-primitive_root.pow(7u32), one]); - + let a0 = BoundaryConstraint::new_simple(0, one); + let a1 = BoundaryConstraint::new_simple(1, one); + let result = BoundaryConstraint::new_simple(7, FieldElement::::from(32)); + + let trace_coset = Coset::get_coset_points(&Coset::new_standard(3)); + let eval_point = CirclePoint::::GENERATOR * 2; + let a0_zerofier = (&eval_point + &trace_coset[0].clone().conjugate()).x - one; + let a1_zerofier = (&eval_point + &trace_coset[1].clone().conjugate()).x - one; + let res_zerofier = (&eval_point + &trace_coset[7].clone().conjugate()).x - one; let expected_zerofier = a0_zerofier * a1_zerofier * res_zerofier; - let zerofier = constraints.compute_zerofier(&primitive_root, 0); + let constraints = BoundaryConstraints::from_constraints(vec![a0, a1, result]); + let zerofier = constraints.evaluate_zerofier(&trace_coset, 0, &eval_point); assert_eq!(expected_zerofier, zerofier); } diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index fe5933c0d..286f5315d 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -87,6 +87,7 @@ where let period = self.period(); // This accumulates evaluations of the point at the zerofier at all the offsets positions. (1..=self.end_exemptions()) + // FIXME: I think this is wrong because exemption should be and element of the stndard coset instead of the group. (-nicole) .map(|exemption| trace_group_generator * ((trace_length - exemption * period) as u128)) .fold(one.clone(), |acc, vanishing_point| { acc * ((eval_point + vanishing_point.conjugate()).x - &one) From a691d42779e64ec5f14a704ce1d63e13bc567b52 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 5 Nov 2024 18:05:21 -0300 Subject: [PATCH 83/93] change circle polynomial to accept a reference --- math/src/circle/polynomial.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index 77f0fd76c..4cab7c0c3 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -75,7 +75,7 @@ pub fn interpolate_cfft( /// Note: This implementation uses a straightforward approach and is intended for testing purposes only. pub fn evaluate_point( coef: &Vec>, - point: CirclePoint, + point: &CirclePoint, ) -> FieldElement { let order = coef.len(); assert!( @@ -344,7 +344,7 @@ mod tests { let coset = Coset::new_standard(5); let coset_points = Coset::get_coset_points(&coset); - assert_eq!(evals[4], evaluate_point(&coeff, coset_points[4].clone())); + assert_eq!(evals[4], evaluate_point(&coeff, &coset_points[4].clone())); let new_coeff = interpolate_cfft(evals); From edd05ab095ac447f3ea251d980825fe7788be0fb Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 5 Nov 2024 18:05:47 -0300 Subject: [PATCH 84/93] remove generic IsField --- provers/circle_stark/src/air.rs | 35 +- .../circle_stark/src/constraints/evaluator.rs | 442 +++++++++--------- .../src/constraints/transition.rs | 32 +- provers/circle_stark/src/domain.rs | 16 +- provers/circle_stark/src/frame.rs | 16 +- provers/circle_stark/src/table.rs | 27 +- provers/circle_stark/src/trace.rs | 109 ++--- 7 files changed, 327 insertions(+), 350 deletions(-) diff --git a/provers/circle_stark/src/air.rs b/provers/circle_stark/src/air.rs index 7dbe3f2ca..4a82f5e94 100644 --- a/provers/circle_stark/src/air.rs +++ b/provers/circle_stark/src/air.rs @@ -5,20 +5,15 @@ use crate::{ use lambdaworks_math::{ circle::point::{CirclePoint, HasCircleParams}, field::{ - element::FieldElement, - traits::{IsFFTField, IsField, IsSubFieldOf}, + element::FieldElement, fields::mersenne31::field::Mersenne31Field, traits::{IsFFTField, IsField, IsSubFieldOf} }, }; use std::collections::HashMap; -type ZerofierGroupKey = (usize, usize, Option, Option, usize); +// type ZerofierGroupKey = (usize, usize, Option, Option, usize); +type ZerofierGroupKey = (usize); + /// AIR is a representation of the Constraints pub trait AIR { - type Field: IsFFTField - + IsSubFieldOf - + Send - + Sync - + HasCircleParams; - type FieldExtension: IsField + Send + Sync + HasCircleParams; type PublicInputs; fn new(trace_length: usize, pub_inputs: &Self::PublicInputs) -> Self; @@ -33,17 +28,17 @@ pub trait AIR { /// `Self::Field`, since they are the evaluations of the main trace at the LDE domain. fn compute_transition_prover( &self, - frame: &Frame, - ) -> Vec> { + frame: &Frame, + ) -> Vec> { let mut evaluations = - vec![FieldElement::::zero(); self.num_transition_constraints()]; + vec![FieldElement::::zero(); self.num_transition_constraints()]; self.transition_constraints() .iter() .for_each(|c| c.evaluate(frame, &mut evaluations)); evaluations } - fn boundary_constraints(&self) -> BoundaryConstraints; + fn boundary_constraints(&self) -> BoundaryConstraints; /// The method called by the verifier to evaluate the transitions at the out of domain frame. /// In the case of the verifier, both main and auxiliary tables of the evaluation frame take @@ -53,8 +48,8 @@ pub trait AIR { /// `compute_transition_prover` should return the same values. fn compute_transition_verifier( &self, - frame: &Frame, - ) -> Vec>; + frame: &Frame, + ) -> Vec>; fn context(&self) -> &AirContext; @@ -64,7 +59,7 @@ pub trait AIR { 2 } - fn trace_group_generator(&self) -> CirclePoint { + fn trace_group_generator(&self) -> CirclePoint { let trace_length = self.trace_length(); let log_2_length = trace_length.trailing_zeros(); CirclePoint::get_generator_of_subgroup(log_2_length) @@ -78,14 +73,14 @@ pub trait AIR { fn transition_constraints( &self, - ) -> &Vec>>; + ) -> &Vec>; fn transition_zerofier_evaluations( &self, - domain: &Domain, - ) -> Vec>> { + domain: &Domain, + ) -> Vec>> { let mut evals = vec![Vec::new(); self.num_transition_constraints()]; - let mut zerofier_groups: HashMap>> = + let mut zerofier_groups: HashMap>> = HashMap::new(); self.transition_constraints().iter().for_each(|c| { diff --git a/provers/circle_stark/src/constraints/evaluator.rs b/provers/circle_stark/src/constraints/evaluator.rs index 140a104ab..beb70cf10 100644 --- a/provers/circle_stark/src/constraints/evaluator.rs +++ b/provers/circle_stark/src/constraints/evaluator.rs @@ -1,221 +1,221 @@ -use super::boundary::BoundaryConstraints; -#[cfg(all(debug_assertions, not(feature = "parallel")))] -use crate::debug::check_boundary_polys_divisibility; -use crate::domain::Domain; -use crate::trace::LDETraceTable; -use crate::traits::AIR; -use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain}; -use itertools::Itertools; -#[cfg(all(debug_assertions, not(feature = "parallel")))] -use lambdaworks_math::polynomial::Polynomial; -use lambdaworks_math::{fft::errors::FFTError, field::element::FieldElement, traits::AsBytes}; -#[cfg(feature = "parallel")] -use rayon::{ - iter::IndexedParallelIterator, - prelude::{IntoParallelIterator, ParallelIterator}, -}; -#[cfg(feature = "instruments")] -use std::time::Instant; - -pub struct ConstraintEvaluator { - boundary_constraints: BoundaryConstraints, -} -impl ConstraintEvaluator { - pub fn new(air: &A, rap_challenges: &[FieldElement]) -> Self { - let boundary_constraints = air.boundary_constraints(rap_challenges); - - Self { - boundary_constraints, - } - } - - pub(crate) fn evaluate( - &self, - air: &A, - lde_trace: &LDETraceTable, - domain: &Domain, - transition_coefficients: &[FieldElement], - boundary_coefficients: &[FieldElement], - rap_challenges: &[FieldElement], - ) -> Vec> - where - FieldElement: AsBytes + Send + Sync, - FieldElement: AsBytes + Send + Sync, - A: Send + Sync, - { - let boundary_constraints = &self.boundary_constraints; - let number_of_b_constraints = boundary_constraints.constraints.len(); - let boundary_zerofiers_inverse_evaluations: Vec>> = - boundary_constraints - .constraints - .iter() - .map(|bc| { - let point = &domain.trace_primitive_root.pow(bc.step as u64); - let mut evals = domain - .lde_roots_of_unity_coset - .iter() - .map(|v| v.clone() - point) - .collect::>>(); - FieldElement::inplace_batch_inverse(&mut evals).unwrap(); - evals - }) - .collect::>>>(); - - #[cfg(all(debug_assertions, not(feature = "parallel")))] - let boundary_polys: Vec>> = Vec::new(); - - #[cfg(feature = "instruments")] - let timer = Instant::now(); - - let lde_periodic_columns = air - .get_periodic_column_polynomials() - .iter() - .map(|poly| { - evaluate_polynomial_on_lde_domain( - poly, - domain.blowup_factor, - domain.interpolation_domain_size, - &domain.coset_offset, - ) - }) - .collect::>>, FFTError>>() - .unwrap(); - - #[cfg(feature = "instruments")] - println!( - " Evaluating periodic columns on lde: {:#?}", - timer.elapsed() - ); - - #[cfg(feature = "instruments")] - let timer = Instant::now(); - - let boundary_polys_evaluations = boundary_constraints - .constraints - .iter() - .map(|constraint| { - if constraint.is_aux { - (0..lde_trace.num_rows()) - .map(|row| { - let v = lde_trace.get_aux(row, constraint.col); - v - &constraint.value - }) - .collect_vec() - } else { - (0..lde_trace.num_rows()) - .map(|row| { - let v = lde_trace.get_main(row, constraint.col); - v - &constraint.value - }) - .collect_vec() - } - }) - .collect_vec(); - - #[cfg(feature = "instruments")] - println!(" Created boundary polynomials: {:#?}", timer.elapsed()); - #[cfg(feature = "instruments")] - let timer = Instant::now(); - - #[cfg(feature = "parallel")] - let boundary_eval_iter = (0..domain.lde_roots_of_unity_coset.len()).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let boundary_eval_iter = 0..domain.lde_roots_of_unity_coset.len(); - - let boundary_evaluation: Vec<_> = boundary_eval_iter - .map(|domain_index| { - (0..number_of_b_constraints) - .zip(boundary_coefficients) - .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { - acc + &boundary_zerofiers_inverse_evaluations[constraint_index] - [domain_index] - * beta - * &boundary_polys_evaluations[constraint_index][domain_index] - }) - }) - .collect(); - - #[cfg(feature = "instruments")] - println!( - " Evaluated boundary polynomials on LDE: {:#?}", - timer.elapsed() - ); - - #[cfg(all(debug_assertions, not(feature = "parallel")))] - let boundary_zerofiers = Vec::new(); - - #[cfg(all(debug_assertions, not(feature = "parallel")))] - check_boundary_polys_divisibility(boundary_polys, boundary_zerofiers); - - #[cfg(all(debug_assertions, not(feature = "parallel")))] - let mut transition_evaluations = Vec::new(); - - #[cfg(feature = "instruments")] - let timer = Instant::now(); - let zerofiers_evals = air.transition_zerofier_evaluations(domain); - #[cfg(feature = "instruments")] - println!( - " Evaluated transition zerofiers: {:#?}", - timer.elapsed() - ); - - // Iterate over all LDE domain and compute the part of the composition polynomial - // related to the transition constraints and add it to the already computed part of the - // boundary constraints. - - #[cfg(feature = "instruments")] - let timer = Instant::now(); - let evaluations_t_iter = 0..domain.lde_roots_of_unity_coset.len(); - - #[cfg(feature = "parallel")] - let boundary_evaluation = boundary_evaluation.into_par_iter(); - #[cfg(feature = "parallel")] - let evaluations_t_iter = evaluations_t_iter.into_par_iter(); - - let evaluations_t = evaluations_t_iter - .zip(boundary_evaluation) - .map(|(i, boundary)| { - let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); - - let periodic_values: Vec<_> = lde_periodic_columns - .iter() - .map(|col| col[i].clone()) - .collect(); - - // Compute all the transition constraints at this point of the LDE domain. - let evaluations_transition = - air.compute_transition_prover(&frame, &periodic_values, rap_challenges); - - #[cfg(all(debug_assertions, not(feature = "parallel")))] - transition_evaluations.push(evaluations_transition.clone()); - - // Add each term of the transition constraints to the composition polynomial, including the zerofier, - // the challenge and the exemption polynomial if it is necessary. - let acc_transition = itertools::izip!( - evaluations_transition, - &zerofiers_evals, - transition_coefficients - ) - .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { - // Zerofier evaluations are cyclical, so we only calculate one cycle. - // This means that here we have to wrap around - // Ex: Suppose the full zerofier vector is Z = [1,2,3,1,2,3] - // we will instead have calculated Z' = [1,2,3] - // Now if you need Z[4] this is equal to Z'[1] - let wrapped_idx = i % zerof_eval.len(); - acc + &zerof_eval[wrapped_idx] * eval * beta - }); - - acc_transition + boundary - }) - .collect(); - - #[cfg(feature = "instruments")] - println!( - " Evaluated transitions and accumulated results: {:#?}", - timer.elapsed() - ); - - evaluations_t - } -} +// use super::boundary::BoundaryConstraints; +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// use crate::debug::check_boundary_polys_divisibility; +// use crate::domain::Domain; +// use crate::trace::LDETraceTable; +// use crate::traits::AIR; +// use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain}; +// use itertools::Itertools; +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// use lambdaworks_math::polynomial::Polynomial; +// use lambdaworks_math::{fft::errors::FFTError, field::element::FieldElement, traits::AsBytes}; +// #[cfg(feature = "parallel")] +// use rayon::{ +// iter::IndexedParallelIterator, +// prelude::{IntoParallelIterator, ParallelIterator}, +// }; +// #[cfg(feature = "instruments")] +// use std::time::Instant; + +// pub struct ConstraintEvaluator { +// boundary_constraints: BoundaryConstraints, +// } +// impl ConstraintEvaluator { +// pub fn new(air: &A, rap_challenges: &[FieldElement]) -> Self { +// let boundary_constraints = air.boundary_constraints(rap_challenges); + +// Self { +// boundary_constraints, +// } +// } + +// pub(crate) fn evaluate( +// &self, +// air: &A, +// lde_trace: &LDETraceTable, +// domain: &Domain, +// transition_coefficients: &[FieldElement], +// boundary_coefficients: &[FieldElement], +// rap_challenges: &[FieldElement], +// ) -> Vec> +// where +// FieldElement: AsBytes + Send + Sync, +// FieldElement: AsBytes + Send + Sync, +// A: Send + Sync, +// { +// let boundary_constraints = &self.boundary_constraints; +// let number_of_b_constraints = boundary_constraints.constraints.len(); +// let boundary_zerofiers_inverse_evaluations: Vec>> = +// boundary_constraints +// .constraints +// .iter() +// .map(|bc| { +// let point = &domain.trace_primitive_root.pow(bc.step as u64); +// let mut evals = domain +// .lde_roots_of_unity_coset +// .iter() +// .map(|v| v.clone() - point) +// .collect::>>(); +// FieldElement::inplace_batch_inverse(&mut evals).unwrap(); +// evals +// }) +// .collect::>>>(); + +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// let boundary_polys: Vec>> = Vec::new(); + +// #[cfg(feature = "instruments")] +// let timer = Instant::now(); + +// let lde_periodic_columns = air +// .get_periodic_column_polynomials() +// .iter() +// .map(|poly| { +// evaluate_polynomial_on_lde_domain( +// poly, +// domain.blowup_factor, +// domain.interpolation_domain_size, +// &domain.coset_offset, +// ) +// }) +// .collect::>>, FFTError>>() +// .unwrap(); + +// #[cfg(feature = "instruments")] +// println!( +// " Evaluating periodic columns on lde: {:#?}", +// timer.elapsed() +// ); + +// #[cfg(feature = "instruments")] +// let timer = Instant::now(); + +// let boundary_polys_evaluations = boundary_constraints +// .constraints +// .iter() +// .map(|constraint| { +// if constraint.is_aux { +// (0..lde_trace.num_rows()) +// .map(|row| { +// let v = lde_trace.get_aux(row, constraint.col); +// v - &constraint.value +// }) +// .collect_vec() +// } else { +// (0..lde_trace.num_rows()) +// .map(|row| { +// let v = lde_trace.get_main(row, constraint.col); +// v - &constraint.value +// }) +// .collect_vec() +// } +// }) +// .collect_vec(); + +// #[cfg(feature = "instruments")] +// println!(" Created boundary polynomials: {:#?}", timer.elapsed()); +// #[cfg(feature = "instruments")] +// let timer = Instant::now(); + +// #[cfg(feature = "parallel")] +// let boundary_eval_iter = (0..domain.lde_roots_of_unity_coset.len()).into_par_iter(); +// #[cfg(not(feature = "parallel"))] +// let boundary_eval_iter = 0..domain.lde_roots_of_unity_coset.len(); + +// let boundary_evaluation: Vec<_> = boundary_eval_iter +// .map(|domain_index| { +// (0..number_of_b_constraints) +// .zip(boundary_coefficients) +// .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { +// acc + &boundary_zerofiers_inverse_evaluations[constraint_index] +// [domain_index] +// * beta +// * &boundary_polys_evaluations[constraint_index][domain_index] +// }) +// }) +// .collect(); + +// #[cfg(feature = "instruments")] +// println!( +// " Evaluated boundary polynomials on LDE: {:#?}", +// timer.elapsed() +// ); + +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// let boundary_zerofiers = Vec::new(); + +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// check_boundary_polys_divisibility(boundary_polys, boundary_zerofiers); + +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// let mut transition_evaluations = Vec::new(); + +// #[cfg(feature = "instruments")] +// let timer = Instant::now(); +// let zerofiers_evals = air.transition_zerofier_evaluations(domain); +// #[cfg(feature = "instruments")] +// println!( +// " Evaluated transition zerofiers: {:#?}", +// timer.elapsed() +// ); + +// // Iterate over all LDE domain and compute the part of the composition polynomial +// // related to the transition constraints and add it to the already computed part of the +// // boundary constraints. + +// #[cfg(feature = "instruments")] +// let timer = Instant::now(); +// let evaluations_t_iter = 0..domain.lde_roots_of_unity_coset.len(); + +// #[cfg(feature = "parallel")] +// let boundary_evaluation = boundary_evaluation.into_par_iter(); +// #[cfg(feature = "parallel")] +// let evaluations_t_iter = evaluations_t_iter.into_par_iter(); + +// let evaluations_t = evaluations_t_iter +// .zip(boundary_evaluation) +// .map(|(i, boundary)| { +// let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); + +// let periodic_values: Vec<_> = lde_periodic_columns +// .iter() +// .map(|col| col[i].clone()) +// .collect(); + +// // Compute all the transition constraints at this point of the LDE domain. +// let evaluations_transition = +// air.compute_transition_prover(&frame, &periodic_values, rap_challenges); + +// #[cfg(all(debug_assertions, not(feature = "parallel")))] +// transition_evaluations.push(evaluations_transition.clone()); + +// // Add each term of the transition constraints to the composition polynomial, including the zerofier, +// // the challenge and the exemption polynomial if it is necessary. +// let acc_transition = itertools::izip!( +// evaluations_transition, +// &zerofiers_evals, +// transition_coefficients +// ) +// .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { +// // Zerofier evaluations are cyclical, so we only calculate one cycle. +// // This means that here we have to wrap around +// // Ex: Suppose the full zerofier vector is Z = [1,2,3,1,2,3] +// // we will instead have calculated Z' = [1,2,3] +// // Now if you need Z[4] this is equal to Z'[1] +// let wrapped_idx = i % zerof_eval.len(); +// acc + &zerof_eval[wrapped_idx] * eval * beta +// }); + +// acc_transition + boundary +// }) +// .collect(); + +// #[cfg(feature = "instruments")] +// println!( +// " Evaluated transitions and accumulated results: {:#?}", +// timer.elapsed() +// ); + +// evaluations_t +// } +// } diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index fe5933c0d..277156623 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -1,14 +1,12 @@ use crate::domain::Domain; use crate::frame::Frame; -use lambdaworks_math::circle::point::{CirclePoint, HasCircleParams}; +use lambdaworks_math::circle::point::CirclePoint; use lambdaworks_math::field::element::FieldElement; -use lambdaworks_math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; +use lambdaworks_math::field::fields::mersenne31::field::Mersenne31Field; +use lambdaworks_math::field::traits::IsField; /// TransitionConstraint represents the behaviour that a transition constraint /// over the computation that wants to be proven must comply with. -pub trait TransitionConstraint: Send + Sync -where - F: IsSubFieldOf + IsFFTField + Send + Sync + HasCircleParams, - E: IsField + Send + Sync + HasCircleParams, +pub trait TransitionConstraint { /// The degree of the constraint interpreting it as a multivariate polynomial. fn degree(&self) -> usize; @@ -26,7 +24,7 @@ where /// the evaluation. /// Once computed, the evaluation should be inserted in the `transition_evaluations` /// vector, in the index corresponding to the constraint as given by `constraint_idx()`. - fn evaluate(&self, frame: &Frame, transition_evaluations: &mut [FieldElement]); + fn evaluate(&self, frame: &Frame, transition_evaluations: &mut [FieldElement]); /// The periodicity the constraint is applied over the trace. /// @@ -75,12 +73,12 @@ where /// Evaluate the `eval_point` in the polynomial that vanishes in all the exemptions points. fn evaluate_end_exemptions_poly( &self, - eval_point: &CirclePoint, + eval_point: &CirclePoint, // `trace_group_generator` can be calculated with `trace_length` but it is better to precompute it - trace_group_generator: &CirclePoint, + trace_group_generator: &CirclePoint, trace_length: usize, - ) -> FieldElement { - let one = FieldElement::::one(); + ) -> FieldElement { + let one = FieldElement::::one(); if self.end_exemptions() == 0 { return one; } @@ -97,7 +95,7 @@ where /// TODO: See if we can evaluate using cfft. /// TODO: See if we can optimize computing only some evaluations and cycle them as in regular stark. #[allow(unstable_name_collisions)] - fn zerofier_evaluations_on_extended_domain(&self, domain: &Domain) -> Vec> { + fn zerofier_evaluations_on_extended_domain(&self, domain: &Domain) -> Vec> { let blowup_factor = domain.blowup_factor; let trace_length = domain.trace_length; let trace_log_2_size = trace_length.trailing_zeros(); @@ -116,7 +114,7 @@ where // TODO: Is there a way to avoid this clone()? let mut x = point.x.clone(); for _ in 1..trace_log_2_size { - x = x.square().double() - FieldElement::::one(); + x = x.square().double() - FieldElement::::one(); } x }) @@ -140,10 +138,10 @@ where #[allow(unstable_name_collisions)] fn evaluate_zerofier( &self, - eval_point: &CirclePoint, - trace_group_generator: &CirclePoint, + eval_point: &CirclePoint, + trace_group_generator: &CirclePoint, trace_length: usize, - ) -> FieldElement { + ) -> FieldElement { // if let Some(exemptions_period) = self.exemptions_period() { // } else { @@ -154,7 +152,7 @@ where let trace_log_2_size = trace_length.trailing_zeros(); let mut x = eval_point.x.clone(); for _ in 1..trace_log_2_size { - x = x.square().double() - FieldElement::::one(); + x = x.square().double() - FieldElement::::one(); } x.inv().unwrap() * end_exemptions_evaluation diff --git a/provers/circle_stark/src/domain.rs b/provers/circle_stark/src/domain.rs index 818515966..a684fddb7 100644 --- a/provers/circle_stark/src/domain.rs +++ b/provers/circle_stark/src/domain.rs @@ -1,26 +1,26 @@ use lambdaworks_math::{ circle::{ cosets::Coset, - point::{CirclePoint, HasCircleParams}, + point::CirclePoint, }, - field::{element::FieldElement, traits::IsFFTField}, + field::fields::mersenne31::field::Mersenne31Field, }; use super::air::AIR; -pub struct Domain> { +pub struct Domain { pub(crate) trace_length: usize, pub(crate) trace_log_2_length: u32, pub(crate) blowup_factor: usize, - pub(crate) trace_coset_points: Vec>, - pub(crate) lde_coset_points: Vec>, - pub(crate) trace_group_generator: CirclePoint, + pub(crate) trace_coset_points: Vec>, + pub(crate) lde_coset_points: Vec>, + pub(crate) trace_group_generator: CirclePoint, } -impl> Domain { +impl Domain { pub fn new(air: &A) -> Self where - A: AIR, + A: AIR, { // Initial definitions let trace_length = air.trace_length(); diff --git a/provers/circle_stark/src/frame.rs b/provers/circle_stark/src/frame.rs index 11e7c92d4..8752420bf 100644 --- a/provers/circle_stark/src/frame.rs +++ b/provers/circle_stark/src/frame.rs @@ -1,28 +1,26 @@ use crate::trace::LDETraceTable; use itertools::Itertools; -use lambdaworks_math::field::{element::FieldElement, traits::IsField}; +use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; /// A frame represents a collection of trace steps. /// The collected steps are all the necessary steps for /// all transition costraints over a trace to be evaluated. #[derive(Clone, Debug, PartialEq)] -pub struct Frame -where - F: IsField, +pub struct Frame { - steps: Vec>>, + steps: Vec>>, } -impl Frame { - pub fn new(steps: Vec>>) -> Self { +impl Frame { + pub fn new(steps: Vec>>) -> Self { Self { steps } } - pub fn get_evaluation_step(&self, step: usize) -> &Vec> { + pub fn get_evaluation_step(&self, step: usize) -> &Vec> { &self.steps[step] } - pub fn read_from_lde(lde_trace: &LDETraceTable, row: usize, offsets: &[usize]) -> Self { + pub fn read_from_lde(lde_trace: &LDETraceTable, row: usize, offsets: &[usize]) -> Self { let num_rows = lde_trace.num_rows(); let lde_steps = offsets diff --git a/provers/circle_stark/src/table.rs b/provers/circle_stark/src/table.rs index e6e45a127..18133d112 100644 --- a/provers/circle_stark/src/table.rs +++ b/provers/circle_stark/src/table.rs @@ -1,6 +1,5 @@ use lambdaworks_math::field::{ - element::FieldElement, - traits::IsField, + element::FieldElement, fields::mersenne31::field::Mersenne31Field, }; /// A two-dimensional Table holding field elements, arranged in a row-major order. @@ -9,16 +8,16 @@ use lambdaworks_math::field::{ /// Since this struct is a representation of a two-dimensional table, all rows should have the same /// length. #[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct Table { - pub data: Vec>, +pub struct Table { + pub data: Vec>, pub width: usize, pub height: usize, } -impl<'t, F: IsField> Table { +impl Table { /// Crates a new Table instance from a one-dimensional array in row major order /// and the intended width of the table. - pub fn new(data: Vec>, width: usize) -> Self { + pub fn new(data: Vec>, width: usize) -> Self { // Check if the intented width is 0, used for creating an empty table. if width == 0 { return Self { @@ -40,7 +39,7 @@ impl<'t, F: IsField> Table { } /// Creates a Table instance from a vector of the intended columns. - pub fn from_columns(columns: Vec>>) -> Self { + pub fn from_columns(columns: Vec>>) -> Self { if columns.is_empty() { return Self::new(Vec::new(), 0); } @@ -62,18 +61,18 @@ impl<'t, F: IsField> Table { } /// Returns a vector of vectors of field elements representing the table rows - pub fn rows(&self) -> Vec>> { + pub fn rows(&self) -> Vec>> { self.data.chunks(self.width).map(|r| r.to_vec()).collect() } /// Given a row index, returns a reference to that row as a slice of field elements. - pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { let row_offset = row_idx * self.width; &self.data[row_offset..row_offset + self.width] } /// Given a row index, returns a mutable reference to that row as a slice of field elements. - pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { + pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { let n_cols = self.width; let row_offset = row_idx * n_cols; &mut self.data[row_offset..row_offset + n_cols] @@ -81,20 +80,20 @@ impl<'t, F: IsField> Table { /// Given a slice of field elements representing a row, appends it to /// the end of the table. - pub fn append_row(&mut self, row: &[FieldElement]) { + pub fn append_row(&mut self, row: &[FieldElement]) { debug_assert_eq!(row.len(), self.width); self.data.extend_from_slice(row); self.height += 1 } /// Returns a reference to the last row of the table - pub fn last_row(&self) -> &[FieldElement] { + pub fn last_row(&self) -> &[FieldElement] { self.get_row(self.height - 1) } /// Returns a vector of vectors of field elements representing the table /// columns - pub fn columns(&self) -> Vec>> { + pub fn columns(&self) -> Vec>> { (0..self.width) .map(|col_idx| { (0..self.height) @@ -105,7 +104,7 @@ impl<'t, F: IsField> Table { } /// Given row and column indexes, returns the stored field element in that position of the table. - pub fn get(&self, row: usize, col: usize) -> &FieldElement { + pub fn get(&self, row: usize, col: usize) -> &FieldElement { let idx = row * self.width + col; &self.data[idx] } diff --git a/provers/circle_stark/src/trace.rs b/provers/circle_stark/src/trace.rs index a9eabd8f0..57c918ec4 100644 --- a/provers/circle_stark/src/trace.rs +++ b/provers/circle_stark/src/trace.rs @@ -2,14 +2,12 @@ use crate::table::Table; use itertools::Itertools; use lambdaworks_math::{ circle::{ - point::{CirclePoint, HasCircleParams}, + point::CirclePoint, polynomial::{evaluate_point, interpolate_cfft}, }, fft::errors::FFTError, field::{ - element::FieldElement, - traits::IsFFTField, - traits::{IsField, IsSubFieldOf}, + element::FieldElement, fields::mersenne31::field::Mersenne31Field, }, }; #[cfg(feature = "parallel")] @@ -23,18 +21,18 @@ use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; /// STARK protocol, such as the step size (number of consecutive rows of the table) /// of the computation being proven. #[derive(Clone, Default, Debug, PartialEq, Eq)] -pub struct TraceTable { - pub table: Table, +pub struct TraceTable { + pub table: Table, pub num_columns: usize, } -impl TraceTable { - pub fn new(data: Vec>, num_columns: usize) -> Self { +impl TraceTable { + pub fn new(data: Vec>, num_columns: usize) -> Self { let table = Table::new(data, num_columns); Self { table, num_columns } } - pub fn from_columns(columns: Vec>>) -> Self { + pub fn from_columns(columns: Vec>>) -> Self { println!("COLUMNS LEN: {}", columns.len()); let num_columns = columns.len(); let table = Table::from_columns(columns); @@ -57,23 +55,23 @@ impl TraceTable { self.table.width } - pub fn rows(&self) -> Vec>> { + pub fn rows(&self) -> Vec>> { self.table.rows() } - pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { self.table.get_row(row_idx) } - pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { + pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { self.table.get_row_mut(row_idx) } - pub fn last_row(&self) -> &[FieldElement] { + pub fn last_row(&self) -> &[FieldElement] { self.get_row(self.n_rows() - 1) } - pub fn columns(&self) -> Vec>> { + pub fn columns(&self) -> Vec>> { self.table.columns() } @@ -83,7 +81,7 @@ impl TraceTable { /// The particular way they are merged is not really important since this function is used to /// aggreagate values distributed across various columns with no importance on their ordering, /// such as to sort them. - pub fn merge_columns(&self, column_indexes: &[usize]) -> Vec> { + pub fn merge_columns(&self, column_indexes: &[usize]) -> Vec> { let mut data = Vec::with_capacity(self.n_rows() * column_indexes.len()); for row_index in 0..self.n_rows() { for column in column_indexes { @@ -93,20 +91,16 @@ impl TraceTable { data } - pub fn compute_trace_polys(&self) -> Vec>> - where - S: IsFFTField + IsSubFieldOf, - FieldElement: Send + Sync, + pub fn compute_trace_polys(&self) -> Vec>> { let columns = self.columns(); #[cfg(feature = "parallel")] let iter = columns.par_iter(); #[cfg(not(feature = "parallel"))] let iter = columns.iter(); - - iter.map(|col| interpolate_cfft(col)) - .collect::>>, FFTError>>() - .unwrap() + // FIX: Replace the .to_vec() + iter.map(|col| interpolate_cfft(col.to_vec())) + .collect::>>>() } /// Given the padding length, appends the last row of the trace table @@ -126,7 +120,7 @@ impl TraceTable { /// The row_idx passed as argument may be greater than the max row index by 1. In this case, /// last row of the trace is cloned, and the value is set in that cloned row. Then, the row is /// appended to the end of the trace. - pub fn set_or_extend(&mut self, row_idx: usize, col_idx: usize, value: &FieldElement) { + pub fn set_or_extend(&mut self, row_idx: usize, col_idx: usize, value: &FieldElement) { debug_assert!(col_idx < self.n_cols()); // NOTE: This is not very nice, but for how the function is being used at the moment, // the passed `row_idx` should never be greater than `self.n_rows() + 1`. This is just @@ -143,19 +137,15 @@ impl TraceTable { } } } -pub struct LDETraceTable -where - F: IsField, +pub struct LDETraceTable { - pub(crate) table: Table, + pub(crate) table: Table, pub(crate) blowup_factor: usize, } -impl LDETraceTable -where - F: IsField, +impl LDETraceTable { - pub fn new(data: Vec>, n_columns: usize, blowup_factor: usize) -> Self { + pub fn new(data: Vec>, n_columns: usize, blowup_factor: usize) -> Self { let table = Table::new(data, n_columns); Self { @@ -164,7 +154,7 @@ where } } - pub fn from_columns(columns: Vec>>, blowup_factor: usize) -> Self { + pub fn from_columns(columns: Vec>>, blowup_factor: usize) -> Self { let table = Table::from_columns(columns); Self { @@ -181,11 +171,11 @@ where self.table.height } - pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { self.table.get_row(row_idx) } - pub fn get_table(&self, row: usize, col: usize) -> &FieldElement { + pub fn get_table(&self, row: usize, col: usize) -> &FieldElement { self.table.get(row, col) } } @@ -196,35 +186,32 @@ where /// compute a transition. /// Example: For a simple Fibonacci computation, if t(x) is the trace polynomial of /// the computation, this will output evaluations t(x), t(g * x), t(g^2 * z). -pub fn get_trace_evaluations( - trace_polys: &[Vec>], - point: &CirclePoint, +pub fn get_trace_evaluations( + trace_polys: &[Vec>], + point: &CirclePoint, frame_offsets: &[usize], - group_generator: &CirclePoint, -) -> Table -where - F: IsField + HasCircleParams, + group_generator: &CirclePoint, +) -> Table { let evaluation_points = frame_offsets .iter() .map(|offset| (group_generator * (*offset as u128)) + point) .collect_vec(); - let evaluations = evaluation_points + let evaluations: Vec<_> = evaluation_points .iter() - .map(|eval_point| { + .flat_map(|eval_point| { trace_polys .iter() .map(|poly| evaluate_point(poly, eval_point)) - .collect_vec() }) - .collect_vec(); + .collect(); let table_width = trace_polys.len(); Table::new(evaluations, table_width) } -pub fn columns2rows(columns: Vec>>) -> Vec>> { +pub fn columns2rows(columns: Vec>>) -> Vec>> { let num_rows = columns[0].len(); let num_cols = columns.len(); @@ -237,20 +224,20 @@ pub fn columns2rows(columns: Vec>>) -> Vec; +// #[cfg(test)] +// mod test { +// use super::TraceTable; +// use lambdaworks_math::field::{element::FieldElement, fields::u64_prime_field::F17}; +// type FE = FieldElement; - #[test] - fn test_cols() { - let col_1 = vec![FE::from(1), FE::from(2), FE::from(5), FE::from(13)]; - let col_2 = vec![FE::from(1), FE::from(3), FE::from(8), FE::from(21)]; +// #[test] +// fn test_cols() { +// let col_1 = vec![FE::from(1), FE::from(2), FE::from(5), FE::from(13)]; +// let col_2 = vec![FE::from(1), FE::from(3), FE::from(8), FE::from(21)]; - let trace_table = TraceTable::from_columns(vec![col_1.clone(), col_2.clone()]); - let res_cols = trace_table.columns(); +// let trace_table = TraceTable::from_columns(vec![col_1.clone(), col_2.clone()]); +// let res_cols = trace_table.columns(); - assert_eq!(res_cols, vec![col_1, col_2]); - } -} +// assert_eq!(res_cols, vec![col_1, col_2]); +// } +// } From d89260ede739d32ad986de12e2fd4473b8f13fee Mon Sep 17 00:00:00 2001 From: Nicole Date: Tue, 5 Nov 2024 19:14:08 -0300 Subject: [PATCH 85/93] fix example simple_fibonacci --- provers/circle_stark/src/air.rs | 20 ++---- .../{ => src}/examples/simple_fibonacci.rs | 72 ++++--------------- 2 files changed, 22 insertions(+), 70 deletions(-) rename provers/circle_stark/{ => src}/examples/simple_fibonacci.rs (74%) diff --git a/provers/circle_stark/src/air.rs b/provers/circle_stark/src/air.rs index 4a82f5e94..201194f4f 100644 --- a/provers/circle_stark/src/air.rs +++ b/provers/circle_stark/src/air.rs @@ -5,7 +5,9 @@ use crate::{ use lambdaworks_math::{ circle::point::{CirclePoint, HasCircleParams}, field::{ - element::FieldElement, fields::mersenne31::field::Mersenne31Field, traits::{IsFFTField, IsField, IsSubFieldOf} + element::FieldElement, + fields::mersenne31::field::Mersenne31Field, + traits::{IsFFTField, IsField, IsSubFieldOf}, }, }; use std::collections::HashMap; @@ -26,10 +28,7 @@ pub trait AIR { /// The method called by the prover to evaluate the transitions corresponding to an evaluation frame. /// In the case of the prover, the main evaluation table of the frame takes values in /// `Self::Field`, since they are the evaluations of the main trace at the LDE domain. - fn compute_transition_prover( - &self, - frame: &Frame, - ) -> Vec> { + fn compute_transition_prover(&self, frame: &Frame) -> Vec> { let mut evaluations = vec![FieldElement::::zero(); self.num_transition_constraints()]; self.transition_constraints() @@ -38,7 +37,7 @@ pub trait AIR { evaluations } - fn boundary_constraints(&self) -> BoundaryConstraints; + fn boundary_constraints(&self) -> BoundaryConstraints; /// The method called by the verifier to evaluate the transitions at the out of domain frame. /// In the case of the verifier, both main and auxiliary tables of the evaluation frame take @@ -46,10 +45,7 @@ pub trait AIR { /// at the out of domain challenge. /// In case `Self::Field` coincides with `Self::FieldExtension`, this method and /// `compute_transition_prover` should return the same values. - fn compute_transition_verifier( - &self, - frame: &Frame, - ) -> Vec>; + fn compute_transition_verifier(&self, frame: &Frame) -> Vec>; fn context(&self) -> &AirContext; @@ -71,9 +67,7 @@ pub trait AIR { fn pub_inputs(&self) -> &Self::PublicInputs; - fn transition_constraints( - &self, - ) -> &Vec>; + fn transition_constraints(&self) -> &Vec>; fn transition_zerofier_evaluations( &self, diff --git a/provers/circle_stark/examples/simple_fibonacci.rs b/provers/circle_stark/src/examples/simple_fibonacci.rs similarity index 74% rename from provers/circle_stark/examples/simple_fibonacci.rs rename to provers/circle_stark/src/examples/simple_fibonacci.rs index 204aa938c..ff241cba2 100644 --- a/provers/circle_stark/examples/simple_fibonacci.rs +++ b/provers/circle_stark/src/examples/simple_fibonacci.rs @@ -10,21 +10,16 @@ use crate::{ traits::AIR, }; use lambdaworks_math::field::{element::FieldElement, traits::IsFFTField}; -use std::marker::PhantomData; - +// use std::marker::PhantomData; #[derive(Clone)] struct FibConstraint { - phantom: PhantomData, + //phantom: PhantomData, } - -impl FibConstraint { - pub fn new() -> Self { - Self { - phantom: PhantomData, - } - } -} - +// impl FibConstraint { +// pub fn new() -> Self { +// Self {} +// } +// } impl TransitionConstraint for FibConstraint where F: IsFFTField + Send + Sync, @@ -32,15 +27,12 @@ where fn degree(&self) -> usize { 1 } - fn constraint_idx(&self) -> usize { 0 } - fn end_exemptions(&self) -> usize { 2 } - fn evaluate( &self, frame: &Frame, @@ -51,17 +43,13 @@ where let first_step = frame.get_evaluation_step(0); let second_step = frame.get_evaluation_step(1); let third_step = frame.get_evaluation_step(2); - - let a0 = first_step.get_main_evaluation_element(0, 0); - let a1 = second_step.get_main_evaluation_element(0, 0); - let a2 = third_step.get_main_evaluation_element(0, 0); - + let a0 = first_step[0]; + let a1 = second_step[0]; + let a2 = third_step[0]; let res = a2 - a1 - a0; - transition_evaluations[self.constraint_idx()] = res; } } - pub struct FibonacciAIR where F: IsFFTField, @@ -71,7 +59,6 @@ where pub_inputs: FibonacciPublicInputs, constraints: Vec>>, } - #[derive(Clone, Debug)] pub struct FibonacciPublicInputs where @@ -80,7 +67,6 @@ where pub a0: FieldElement, pub a1: FieldElement, } - impl AIR for FibonacciAIR where F: IsFFTField + Send + Sync + 'static, @@ -88,25 +74,15 @@ where type Field = F; type FieldExtension = F; type PublicInputs = FibonacciPublicInputs; - - const STEP_SIZE: usize = 1; - - fn new( - trace_length: usize, - pub_inputs: &Self::PublicInputs, - proof_options: &ProofOptions, - ) -> Self { + fn new(trace_length: usize, pub_inputs: &Self::PublicInputs) -> Self { let constraints: Vec>> = vec![Box::new(FibConstraint::new())]; - let context = AirContext { - proof_options: proof_options.clone(), trace_columns: 1, transition_exemptions: vec![2], transition_offsets: vec![0, 1, 2], num_transition_constraints: constraints.len(), }; - Self { pub_inputs: pub_inputs.clone(), context, @@ -114,63 +90,45 @@ where constraints, } } - fn composition_poly_degree_bound(&self) -> usize { self.trace_length() } - fn transition_constraints(&self) -> &Vec>> { &self.constraints } - - fn boundary_constraints( - &self, - _rap_challenges: &[FieldElement], - ) -> BoundaryConstraints { - let a0 = BoundaryConstraint::new_simple_main(0, self.pub_inputs.a0.clone()); - let a1 = BoundaryConstraint::new_simple_main(1, self.pub_inputs.a1.clone()); - + fn boundary_constraints(&self) -> BoundaryConstraints { + let a0 = BoundaryConstraint::new_simple(0, self.pub_inputs.a0.clone()); + let a1 = BoundaryConstraint::new_simple(1, self.pub_inputs.a1.clone()); BoundaryConstraints::from_constraints(vec![a0, a1]) } - fn context(&self) -> &AirContext { &self.context } - fn trace_length(&self) -> usize { self.trace_length } - fn trace_layout(&self) -> (usize, usize) { (1, 0) } - fn pub_inputs(&self) -> &Self::PublicInputs { &self.pub_inputs } - fn compute_transition_verifier( &self, frame: &Frame, - periodic_values: &[FieldElement], - rap_challenges: &[FieldElement], ) -> Vec> { - self.compute_transition_prover(frame, periodic_values, rap_challenges) + self.compute_transition_prover(frame) } } - pub fn fibonacci_trace( initial_values: [FieldElement; 2], trace_length: usize, ) -> TraceTable { let mut ret: Vec> = vec![]; - ret.push(initial_values[0].clone()); ret.push(initial_values[1].clone()); - for i in 2..(trace_length) { ret.push(ret[i - 1].clone() + ret[i - 2].clone()); } - TraceTable::from_columns(vec![ret], 1, 1) } From 2197d2f01c86d74bb65406719a41995388ea9801 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Wed, 6 Nov 2024 11:08:10 -0300 Subject: [PATCH 86/93] add integration test --- provers/circle_stark/Cargo.toml | 3 + .../src/constraints/transition.rs | 1 - provers/circle_stark/src/examples/mod.rs | 1 + .../src/examples/simple_fibonacci.rs | 79 ++++++++----------- provers/circle_stark/src/lib.rs | 6 ++ provers/circle_stark/src/tests/integration.rs | 38 +++++++++ provers/circle_stark/src/tests/mod.rs | 1 + 7 files changed, 80 insertions(+), 49 deletions(-) create mode 100644 provers/circle_stark/src/examples/mod.rs create mode 100644 provers/circle_stark/src/tests/integration.rs create mode 100644 provers/circle_stark/src/tests/mod.rs diff --git a/provers/circle_stark/Cargo.toml b/provers/circle_stark/Cargo.toml index f167233e8..27a7a54c0 100644 --- a/provers/circle_stark/Cargo.toml +++ b/provers/circle_stark/Cargo.toml @@ -15,3 +15,6 @@ lambdaworks-crypto = { workspace = true, features = ["std", "serde"] } thiserror = "1.0.38" itertools = "0.11.0" serde = { version = "1.0", features = ["derive"] } + +[dev-dependencies] +test-log = { version = "0.2.11", features = ["log"] } diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index 8d51e4d5c..25d9da946 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -3,7 +3,6 @@ use crate::frame::Frame; use lambdaworks_math::circle::point::CirclePoint; use lambdaworks_math::field::element::FieldElement; use lambdaworks_math::field::fields::mersenne31::field::Mersenne31Field; -use lambdaworks_math::field::traits::IsField; /// TransitionConstraint represents the behaviour that a transition constraint /// over the computation that wants to be proven must comply with. pub trait TransitionConstraint diff --git a/provers/circle_stark/src/examples/mod.rs b/provers/circle_stark/src/examples/mod.rs new file mode 100644 index 000000000..c8bec26b5 --- /dev/null +++ b/provers/circle_stark/src/examples/mod.rs @@ -0,0 +1 @@ +pub mod simple_fibonacci; \ No newline at end of file diff --git a/provers/circle_stark/src/examples/simple_fibonacci.rs b/provers/circle_stark/src/examples/simple_fibonacci.rs index ff241cba2..ae2d35fde 100644 --- a/provers/circle_stark/src/examples/simple_fibonacci.rs +++ b/provers/circle_stark/src/examples/simple_fibonacci.rs @@ -3,27 +3,22 @@ use crate::{ boundary::{BoundaryConstraint, BoundaryConstraints}, transition::TransitionConstraint, }, - context::AirContext, + air_context::AirContext, frame::Frame, - proof::options::ProofOptions, trace::TraceTable, - traits::AIR, + air::AIR, }; -use lambdaworks_math::field::{element::FieldElement, traits::IsFFTField}; +use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field, traits::IsFFTField}; // use std::marker::PhantomData; #[derive(Clone)] -struct FibConstraint { - //phantom: PhantomData, +struct FibConstraint; + +impl FibConstraint { + pub fn new() -> Self { + Self {} + } } -// impl FibConstraint { -// pub fn new() -> Self { -// Self {} -// } -// } -impl TransitionConstraint for FibConstraint -where - F: IsFFTField + Send + Sync, -{ +impl TransitionConstraint for FibConstraint { fn degree(&self) -> usize { 1 } @@ -35,10 +30,8 @@ where } fn evaluate( &self, - frame: &Frame, - transition_evaluations: &mut [FieldElement], - _periodic_values: &[FieldElement], - _rap_challenges: &[FieldElement], + frame: &Frame, + transition_evaluations: &mut [FieldElement], ) { let first_step = frame.get_evaluation_step(0); let second_step = frame.get_evaluation_step(1); @@ -50,32 +43,22 @@ where transition_evaluations[self.constraint_idx()] = res; } } -pub struct FibonacciAIR -where - F: IsFFTField, -{ +pub struct FibonacciAIR { context: AirContext, trace_length: usize, - pub_inputs: FibonacciPublicInputs, - constraints: Vec>>, + pub_inputs: FibonacciPublicInputs, + constraints: Vec>, } #[derive(Clone, Debug)] -pub struct FibonacciPublicInputs -where - F: IsFFTField, -{ - pub a0: FieldElement, - pub a1: FieldElement, +pub struct FibonacciPublicInputs{ + pub a0: FieldElement, + pub a1: FieldElement, } -impl AIR for FibonacciAIR -where - F: IsFFTField + Send + Sync + 'static, +impl AIR for FibonacciAIR { - type Field = F; - type FieldExtension = F; - type PublicInputs = FibonacciPublicInputs; + type PublicInputs = FibonacciPublicInputs; fn new(trace_length: usize, pub_inputs: &Self::PublicInputs) -> Self { - let constraints: Vec>> = + let constraints: Vec> = vec![Box::new(FibConstraint::new())]; let context = AirContext { trace_columns: 1, @@ -93,7 +76,7 @@ where fn composition_poly_degree_bound(&self) -> usize { self.trace_length() } - fn transition_constraints(&self) -> &Vec>> { + fn transition_constraints(&self) -> &Vec> { &self.constraints } fn boundary_constraints(&self) -> BoundaryConstraints { @@ -107,28 +90,28 @@ where fn trace_length(&self) -> usize { self.trace_length } - fn trace_layout(&self) -> (usize, usize) { - (1, 0) + fn trace_layout(&self) -> usize { + 1 } fn pub_inputs(&self) -> &Self::PublicInputs { &self.pub_inputs } fn compute_transition_verifier( &self, - frame: &Frame, - ) -> Vec> { + frame: &Frame, + ) -> Vec> { self.compute_transition_prover(frame) } } -pub fn fibonacci_trace( - initial_values: [FieldElement; 2], +pub fn fibonacci_trace( + initial_values: [FieldElement; 2], trace_length: usize, -) -> TraceTable { - let mut ret: Vec> = vec![]; +) -> TraceTable { + let mut ret: Vec> = vec![]; ret.push(initial_values[0].clone()); ret.push(initial_values[1].clone()); for i in 2..(trace_length) { ret.push(ret[i - 1].clone() + ret[i - 2].clone()); } - TraceTable::from_columns(vec![ret], 1, 1) + TraceTable::from_columns(vec![ret]) } diff --git a/provers/circle_stark/src/lib.rs b/provers/circle_stark/src/lib.rs index cb19024ce..f4108cf40 100644 --- a/provers/circle_stark/src/lib.rs +++ b/provers/circle_stark/src/lib.rs @@ -8,3 +8,9 @@ pub mod prover; pub mod table; pub mod trace; pub mod vanishing_poly; +pub mod examples; + + + +#[cfg(test)] +pub mod tests; \ No newline at end of file diff --git a/provers/circle_stark/src/tests/integration.rs b/provers/circle_stark/src/tests/integration.rs new file mode 100644 index 000000000..7b9313097 --- /dev/null +++ b/provers/circle_stark/src/tests/integration.rs @@ -0,0 +1,38 @@ +use lambdaworks_math::field::{ + element::FieldElement, fields::{fft_friendly::stark_252_prime_field::Stark252PrimeField, mersenne31::field::Mersenne31Field}, +}; + +use crate::{ + examples::{ + simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs}, + }, + // proof::options::ProofOptions, + // prover::{IsStarkProver, Prover}, +}; + +#[test_log::test] +fn test_prove_fib() { + + type FE = FieldElement; + + let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 1024); + + let pub_inputs = FibonacciPublicInputs { + a0: FE::one(), + a1: FE::one(), + }; + + let proof = Prover::>::prove( + &trace, + &pub_inputs, + &proof_options, + StoneProverTranscript::new(&[]), + ); +// .unwrap(); +// assert!(Verifier::>::verify( +// &proof, +// &pub_inputs, +// &proof_options, +// StoneProverTranscript::new(&[]), +// )); +} diff --git a/provers/circle_stark/src/tests/mod.rs b/provers/circle_stark/src/tests/mod.rs new file mode 100644 index 000000000..605134bbc --- /dev/null +++ b/provers/circle_stark/src/tests/mod.rs @@ -0,0 +1 @@ +mod integration; \ No newline at end of file From 12e1d4adfee977c4f52427e4c79a2a5d7c4c3903 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 6 Nov 2024 13:16:22 -0300 Subject: [PATCH 87/93] Prover validate boundary and transition constraints --- math/src/circle/cosets.rs | 2 +- provers/circle_stark/src/frame.rs | 3 +- provers/circle_stark/src/prover.rs | 142 ++++++++++++++---- provers/circle_stark/src/tests/integration.rs | 34 ++--- provers/circle_stark/src/trace.rs | 40 +++-- 5 files changed, 150 insertions(+), 71 deletions(-) diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 957097efb..b9fcdafcc 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -57,7 +57,7 @@ impl Coset { pub fn get_coset_points(coset: &Self) -> Vec> { // g_n the generator of the subgroup of order n. let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); - let size: u8 = 1 << coset.log_2_size; + let size: usize = 1 << coset.log_2_size; core::iter::successors(Some(coset.shift.clone()), move |prev| { Some(prev + &generator_n) }) diff --git a/provers/circle_stark/src/frame.rs b/provers/circle_stark/src/frame.rs index 8752420bf..27abc239b 100644 --- a/provers/circle_stark/src/frame.rs +++ b/provers/circle_stark/src/frame.rs @@ -6,8 +6,7 @@ use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field:: /// The collected steps are all the necessary steps for /// all transition costraints over a trace to be evaluated. #[derive(Clone, Debug, PartialEq)] -pub struct Frame -{ +pub struct Frame { steps: Vec>>, } diff --git a/provers/circle_stark/src/prover.rs b/provers/circle_stark/src/prover.rs index 4949ac1cf..482a92823 100644 --- a/provers/circle_stark/src/prover.rs +++ b/provers/circle_stark/src/prover.rs @@ -1,52 +1,128 @@ -use crate::config::FriMerkleTree; +use crate::{air::AIR, config::FriMerkleTree, domain::Domain, frame::Frame, trace::{LDETraceTable, TraceTable}}; +use std::marker::PhantomData; use super::config::Commitment; use lambdaworks_math::{ - field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, - circle::polynomial::{interpolate_cfft, evaluate_cfft} + circle::polynomial::{evaluate_cfft, evaluate_point, interpolate_cfft}, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; -const BLOW_UP_FACTOR: usize = 2; +/// A default STARK prover implementing `IsStarkProver`. +pub struct Prover { + phantom: PhantomData, +} + +impl IsStarkProver for Prover {} + +#[derive(Debug)] +pub enum ProvingError { + WrongParameter(String), +} + +pub trait IsStarkProver { + fn prove(trace: &TraceTable, pub_inputs: &A::PublicInputs) { + let air = A::new(trace.n_rows(), pub_inputs); + let domain = Domain::new(&air); + let lde_domain_length = domain.blowup_factor * domain.trace_length; -pub fn prove(trace: Vec>) -> Commitment { - - let lde_domain_size = trace.len() * BLOW_UP_FACTOR; + // For each column, calculate the coefficients of the trace interpolating polynomial. + let mut trace_polys = trace.compute_trace_polys(); - // Returns the coef of the interpolating polinomial of the trace on a natural domain. - let mut trace_poly = interpolate_cfft(trace); + // Evaluate each polynomial in the lde domain. + let lde_trace_evaluations = trace_polys + .iter_mut() + .map(|poly| { + // Padding with zeros the coefficients of the polynomial, so we can evaluate it in the lde domain. + poly.resize(lde_domain_length, FieldElement::zero()); + evaluate_cfft(poly.to_vec()) + }) + .collect::>>>(); - // Padding with zeros the coefficients of the polynomial, so we can evaluate it in the lde domain. - trace_poly.resize(lde_domain_size, FieldElement::zero()); - let lde_eval = evaluate_cfft(trace_poly); + // TODO: Commit on lde trace evaluations. - let tree = FriMerkleTree::::build(&lde_eval).unwrap(); - let commitment = tree.root; + // --------- VALIDATE LDE TRACE EVALUATIONS ------------ + + // Interpolate lde trace evaluations. + let lde_coefficients = lde_trace_evaluations + .iter() + .map(|evals| interpolate_cfft(evals.to_vec())) + .collect::>>>(); - commitment + // Evaluate lde trace interpolating polynomial in trace domain. + for point in domain.trace_coset_points { + // println!("{:?}", evaluate_point(&lde_coefficients[0], &point)); + } + + // Crate a LDE_TRACE with a blow up factor of one, so its the same values as the trace. + let lde_trace = LDETraceTable::new(trace.table.data.clone(), 1, 1); + + // --------- VALIDATE BOUNDARY CONSTRAINTS ------------ + air.boundary_constraints() + .constraints + .iter() + .for_each(|constraint| { + let col = constraint.col; + let step = constraint.step; + let boundary_value = constraint.value.clone(); + + let trace_value = trace.table.get(step, col).clone(); + + if boundary_value.clone() != trace_value { + println!("Boundary constraint inconsistency - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value); + } else { + println!("Consistent Boundary constraint - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value) + } + }); + + // --------- VALIDATE TRANSITION CONSTRAINTS ----------- + for row_index in 0..lde_trace.table.height - 2 { + let frame = Frame::read_from_lde(&lde_trace, row_index, &air.context().transition_offsets); + let evaluations = air.compute_transition_prover(&frame); + println!("Transition constraints evaluations: {:?}", evaluations); + } + } } +// const BLOW_UP_FACTOR: usize = 2; + +// pub fn prove(trace: Vec>) -> Commitment { + +// let lde_domain_size = trace.len() * BLOW_UP_FACTOR; + +// // Returns the coef of the interpolating polinomial of the trace on a natural domain. +// let mut trace_poly = interpolate_cfft(trace); + +// // Padding with zeros the coefficients of the polynomial, so we can evaluate it in the lde domain. +// trace_poly.resize(lde_domain_size, FieldElement::zero()); +// let lde_eval = evaluate_cfft(trace_poly); + +// let tree = FriMerkleTree::::build(&lde_eval).unwrap(); +// let commitment = tree.root; + +// commitment +// } #[cfg(test)] mod tests { - + use super::*; type FE = FieldElement; - #[test] - fn basic_test() { - let trace = vec![ - FE::from(1), - FE::from(2), - FE::from(3), - FE::from(4), - FE::from(5), - FE::from(6), - FE::from(7), - FE::from(8), - ]; - - let commitmet = prove(trace); - println!("{:?}", commitmet); - } -} \ No newline at end of file + // #[test] + // fn basic_test() { + // let trace = vec![ + // FE::from(1), + // FE::from(2), + // FE::from(3), + // FE::from(4), + // FE::from(5), + // FE::from(6), + // FE::from(7), + // FE::from(8), + // ]; + + // let commitmet = prove(trace); + // println!("{:?}", commitmet); + // } +} diff --git a/provers/circle_stark/src/tests/integration.rs b/provers/circle_stark/src/tests/integration.rs index 7b9313097..d35bda95a 100644 --- a/provers/circle_stark/src/tests/integration.rs +++ b/provers/circle_stark/src/tests/integration.rs @@ -1,38 +1,34 @@ use lambdaworks_math::field::{ - element::FieldElement, fields::{fft_friendly::stark_252_prime_field::Stark252PrimeField, mersenne31::field::Mersenne31Field}, + element::FieldElement, + fields::{ + fft_friendly::stark_252_prime_field::Stark252PrimeField, mersenne31::field::Mersenne31Field, + }, }; use crate::{ - examples::{ - simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs}, - }, + examples::simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs}, + prover::{IsStarkProver, Prover}, // proof::options::ProofOptions, // prover::{IsStarkProver, Prover}, }; #[test_log::test] fn test_prove_fib() { - type FE = FieldElement; - let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 1024); + let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 512); let pub_inputs = FibonacciPublicInputs { a0: FE::one(), a1: FE::one(), }; - let proof = Prover::>::prove( - &trace, - &pub_inputs, - &proof_options, - StoneProverTranscript::new(&[]), - ); -// .unwrap(); -// assert!(Verifier::>::verify( -// &proof, -// &pub_inputs, -// &proof_options, -// StoneProverTranscript::new(&[]), -// )); + let proof = Prover::::prove(&trace, &pub_inputs); + // .unwrap(); + // assert!(Verifier::>::verify( + // &proof, + // &pub_inputs, + // &proof_options, + // StoneProverTranscript::new(&[]), + // )); } diff --git a/provers/circle_stark/src/trace.rs b/provers/circle_stark/src/trace.rs index 57c918ec4..4ebb5535e 100644 --- a/provers/circle_stark/src/trace.rs +++ b/provers/circle_stark/src/trace.rs @@ -6,9 +6,7 @@ use lambdaworks_math::{ polynomial::{evaluate_point, interpolate_cfft}, }, fft::errors::FFTError, - field::{ - element::FieldElement, fields::mersenne31::field::Mersenne31Field, - }, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; #[cfg(feature = "parallel")] use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; @@ -91,8 +89,7 @@ impl TraceTable { data } - pub fn compute_trace_polys(&self) -> Vec>> - { + pub fn compute_trace_polys(&self) -> Vec>> { let columns = self.columns(); #[cfg(feature = "parallel")] let iter = columns.par_iter(); @@ -120,7 +117,12 @@ impl TraceTable { /// The row_idx passed as argument may be greater than the max row index by 1. In this case, /// last row of the trace is cloned, and the value is set in that cloned row. Then, the row is /// appended to the end of the trace. - pub fn set_or_extend(&mut self, row_idx: usize, col_idx: usize, value: &FieldElement) { + pub fn set_or_extend( + &mut self, + row_idx: usize, + col_idx: usize, + value: &FieldElement, + ) { debug_assert!(col_idx < self.n_cols()); // NOTE: This is not very nice, but for how the function is being used at the moment, // the passed `row_idx` should never be greater than `self.n_rows() + 1`. This is just @@ -137,15 +139,17 @@ impl TraceTable { } } } -pub struct LDETraceTable -{ +pub struct LDETraceTable { pub(crate) table: Table, pub(crate) blowup_factor: usize, } -impl LDETraceTable -{ - pub fn new(data: Vec>, n_columns: usize, blowup_factor: usize) -> Self { +impl LDETraceTable { + pub fn new( + data: Vec>, + n_columns: usize, + blowup_factor: usize, + ) -> Self { let table = Table::new(data, n_columns); Self { @@ -154,7 +158,10 @@ impl LDETraceTable } } - pub fn from_columns(columns: Vec>>, blowup_factor: usize) -> Self { + pub fn from_columns( + columns: Vec>>, + blowup_factor: usize, + ) -> Self { let table = Table::from_columns(columns); Self { @@ -191,14 +198,13 @@ pub fn get_trace_evaluations( point: &CirclePoint, frame_offsets: &[usize], group_generator: &CirclePoint, -) -> Table -{ +) -> Table { let evaluation_points = frame_offsets .iter() .map(|offset| (group_generator * (*offset as u128)) + point) .collect_vec(); - let evaluations: Vec<_> = evaluation_points + let evaluations: Vec<_> = evaluation_points .iter() .flat_map(|eval_point| { trace_polys @@ -211,7 +217,9 @@ pub fn get_trace_evaluations( Table::new(evaluations, table_width) } -pub fn columns2rows(columns: Vec>>) -> Vec>> { +pub fn columns2rows( + columns: Vec>>, +) -> Vec>> { let num_rows = columns[0].len(); let num_cols = columns.len(); From a6a425e91c5ddb745b6a6ea5668faa3881741cec Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 6 Nov 2024 17:38:48 -0300 Subject: [PATCH 88/93] testing: transition zerofier evaluates zero in trace domain --- .../circle_stark/src/constraints/evaluator.rs | 369 +++++++----------- provers/circle_stark/src/frame.rs | 2 +- provers/circle_stark/src/prover.rs | 27 +- 3 files changed, 170 insertions(+), 228 deletions(-) diff --git a/provers/circle_stark/src/constraints/evaluator.rs b/provers/circle_stark/src/constraints/evaluator.rs index beb70cf10..b0b7f88cd 100644 --- a/provers/circle_stark/src/constraints/evaluator.rs +++ b/provers/circle_stark/src/constraints/evaluator.rs @@ -1,221 +1,148 @@ -// use super::boundary::BoundaryConstraints; -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// use crate::debug::check_boundary_polys_divisibility; -// use crate::domain::Domain; -// use crate::trace::LDETraceTable; -// use crate::traits::AIR; -// use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain}; -// use itertools::Itertools; -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// use lambdaworks_math::polynomial::Polynomial; -// use lambdaworks_math::{fft::errors::FFTError, field::element::FieldElement, traits::AsBytes}; -// #[cfg(feature = "parallel")] -// use rayon::{ -// iter::IndexedParallelIterator, -// prelude::{IntoParallelIterator, ParallelIterator}, -// }; -// #[cfg(feature = "instruments")] -// use std::time::Instant; - -// pub struct ConstraintEvaluator { -// boundary_constraints: BoundaryConstraints, -// } -// impl ConstraintEvaluator { -// pub fn new(air: &A, rap_challenges: &[FieldElement]) -> Self { -// let boundary_constraints = air.boundary_constraints(rap_challenges); - -// Self { -// boundary_constraints, -// } -// } - -// pub(crate) fn evaluate( -// &self, -// air: &A, -// lde_trace: &LDETraceTable, -// domain: &Domain, -// transition_coefficients: &[FieldElement], -// boundary_coefficients: &[FieldElement], -// rap_challenges: &[FieldElement], -// ) -> Vec> -// where -// FieldElement: AsBytes + Send + Sync, -// FieldElement: AsBytes + Send + Sync, -// A: Send + Sync, -// { -// let boundary_constraints = &self.boundary_constraints; -// let number_of_b_constraints = boundary_constraints.constraints.len(); -// let boundary_zerofiers_inverse_evaluations: Vec>> = -// boundary_constraints -// .constraints -// .iter() -// .map(|bc| { -// let point = &domain.trace_primitive_root.pow(bc.step as u64); -// let mut evals = domain -// .lde_roots_of_unity_coset -// .iter() -// .map(|v| v.clone() - point) -// .collect::>>(); -// FieldElement::inplace_batch_inverse(&mut evals).unwrap(); -// evals -// }) -// .collect::>>>(); - -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// let boundary_polys: Vec>> = Vec::new(); - -// #[cfg(feature = "instruments")] -// let timer = Instant::now(); - -// let lde_periodic_columns = air -// .get_periodic_column_polynomials() -// .iter() -// .map(|poly| { -// evaluate_polynomial_on_lde_domain( -// poly, -// domain.blowup_factor, -// domain.interpolation_domain_size, -// &domain.coset_offset, -// ) -// }) -// .collect::>>, FFTError>>() -// .unwrap(); - -// #[cfg(feature = "instruments")] -// println!( -// " Evaluating periodic columns on lde: {:#?}", -// timer.elapsed() -// ); - -// #[cfg(feature = "instruments")] -// let timer = Instant::now(); - -// let boundary_polys_evaluations = boundary_constraints -// .constraints -// .iter() -// .map(|constraint| { -// if constraint.is_aux { -// (0..lde_trace.num_rows()) -// .map(|row| { -// let v = lde_trace.get_aux(row, constraint.col); -// v - &constraint.value -// }) -// .collect_vec() -// } else { -// (0..lde_trace.num_rows()) -// .map(|row| { -// let v = lde_trace.get_main(row, constraint.col); -// v - &constraint.value -// }) -// .collect_vec() -// } -// }) -// .collect_vec(); - -// #[cfg(feature = "instruments")] -// println!(" Created boundary polynomials: {:#?}", timer.elapsed()); -// #[cfg(feature = "instruments")] -// let timer = Instant::now(); - -// #[cfg(feature = "parallel")] -// let boundary_eval_iter = (0..domain.lde_roots_of_unity_coset.len()).into_par_iter(); -// #[cfg(not(feature = "parallel"))] -// let boundary_eval_iter = 0..domain.lde_roots_of_unity_coset.len(); - -// let boundary_evaluation: Vec<_> = boundary_eval_iter -// .map(|domain_index| { -// (0..number_of_b_constraints) -// .zip(boundary_coefficients) -// .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { -// acc + &boundary_zerofiers_inverse_evaluations[constraint_index] -// [domain_index] -// * beta -// * &boundary_polys_evaluations[constraint_index][domain_index] -// }) -// }) -// .collect(); - -// #[cfg(feature = "instruments")] -// println!( -// " Evaluated boundary polynomials on LDE: {:#?}", -// timer.elapsed() -// ); - -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// let boundary_zerofiers = Vec::new(); - -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// check_boundary_polys_divisibility(boundary_polys, boundary_zerofiers); - -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// let mut transition_evaluations = Vec::new(); - -// #[cfg(feature = "instruments")] -// let timer = Instant::now(); -// let zerofiers_evals = air.transition_zerofier_evaluations(domain); -// #[cfg(feature = "instruments")] -// println!( -// " Evaluated transition zerofiers: {:#?}", -// timer.elapsed() -// ); - -// // Iterate over all LDE domain and compute the part of the composition polynomial -// // related to the transition constraints and add it to the already computed part of the -// // boundary constraints. - -// #[cfg(feature = "instruments")] -// let timer = Instant::now(); -// let evaluations_t_iter = 0..domain.lde_roots_of_unity_coset.len(); - -// #[cfg(feature = "parallel")] -// let boundary_evaluation = boundary_evaluation.into_par_iter(); -// #[cfg(feature = "parallel")] -// let evaluations_t_iter = evaluations_t_iter.into_par_iter(); - -// let evaluations_t = evaluations_t_iter -// .zip(boundary_evaluation) -// .map(|(i, boundary)| { -// let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); - -// let periodic_values: Vec<_> = lde_periodic_columns -// .iter() -// .map(|col| col[i].clone()) -// .collect(); - -// // Compute all the transition constraints at this point of the LDE domain. -// let evaluations_transition = -// air.compute_transition_prover(&frame, &periodic_values, rap_challenges); - -// #[cfg(all(debug_assertions, not(feature = "parallel")))] -// transition_evaluations.push(evaluations_transition.clone()); - -// // Add each term of the transition constraints to the composition polynomial, including the zerofier, -// // the challenge and the exemption polynomial if it is necessary. -// let acc_transition = itertools::izip!( -// evaluations_transition, -// &zerofiers_evals, -// transition_coefficients -// ) -// .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { -// // Zerofier evaluations are cyclical, so we only calculate one cycle. -// // This means that here we have to wrap around -// // Ex: Suppose the full zerofier vector is Z = [1,2,3,1,2,3] -// // we will instead have calculated Z' = [1,2,3] -// // Now if you need Z[4] this is equal to Z'[1] -// let wrapped_idx = i % zerof_eval.len(); -// acc + &zerof_eval[wrapped_idx] * eval * beta -// }); - -// acc_transition + boundary -// }) -// .collect(); - -// #[cfg(feature = "instruments")] -// println!( -// " Evaluated transitions and accumulated results: {:#?}", -// timer.elapsed() -// ); - -// evaluations_t -// } -// } +use super::boundary::BoundaryConstraints; +use crate::air::AIR; +use std::marker::PhantomData; + +use crate::{domain::Domain, frame::Frame, trace::LDETraceTable}; +use itertools::Itertools; +use lambdaworks_math::circle::polynomial::{evaluate_point, interpolate_cfft}; +use lambdaworks_math::field::element::FieldElement; +use lambdaworks_math::field::fields::mersenne31::field::Mersenne31Field; + +pub struct ConstraintEvaluator { + boundary_constraints: BoundaryConstraints, + phantom: PhantomData, +} +impl ConstraintEvaluator { + pub fn new(air: &A) -> Self { + let boundary_constraints = air.boundary_constraints(); + + Self { + boundary_constraints, + phantom: PhantomData, + } + } + + pub(crate) fn evaluate( + &self, + air: &A, + lde_trace: &LDETraceTable, + domain: &Domain, + transition_coefficients: &[FieldElement], + boundary_coefficients: &[FieldElement], + ) -> Vec> { + let boundary_constraints = &self.boundary_constraints; + let number_of_b_constraints = boundary_constraints.constraints.len(); + + let boundary_zerofiers_inverse_evaluations: Vec>> = + boundary_constraints + .constraints + .iter() + .map(|bc| { + let vanish_point = &domain.trace_coset_points[bc.step]; + let mut evals = domain + .lde_coset_points + .iter() + .map(|eval_point| { + (eval_point + vanish_point.clone().conjugate()).x + - FieldElement::::one() + }) + .collect::>>(); + FieldElement::inplace_batch_inverse(&mut evals).unwrap(); + evals + }) + .collect::>>>(); + + let boundary_polys_evaluations = boundary_constraints + .constraints + .iter() + .map(|constraint| { + (0..lde_trace.num_rows()) + .map(|row| { + let v = lde_trace.table.get(row, constraint.col); + v - &constraint.value + }) + .collect_vec() + }) + .collect_vec(); + + // --------------- BEGIN TESTING ---------------------------- + // Interpolate lde trace evaluations. + // let boundary_poly_coefficients = boundary_polys_evaluations + // .iter() + // .map(|evals| interpolate_cfft(evals.to_vec())) + // .collect::>>>(); + + // Evaluate lde trace interpolating polynomial in trace domain. + // for point in &domain.trace_coset_points { + // println!( + // "{:?}", + // evaluate_point(&boundary_poly_coefficients[0], &point) + // ); + // } + // --------------- END TESTING ---------------------------- + + let boundary_eval_iter = 0..domain.lde_coset_points.len(); + + let boundary_evaluation: Vec<_> = boundary_eval_iter + .map(|domain_index| { + (0..number_of_b_constraints) + .zip(boundary_coefficients) + .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { + acc + &boundary_zerofiers_inverse_evaluations[constraint_index] + [domain_index] + * beta + * &boundary_polys_evaluations[constraint_index][domain_index] + }) + }) + .collect(); + + // Iterate over all LDE domain and compute the part of the composition polynomial + // related to the transition constraints and add it to the already computed part of the + // boundary constraints. + + let zerofiers_evals = air.transition_zerofier_evaluations(domain); + + // --------------- BEGIN TESTING ---------------------------- + // Interpolate lde trace evaluations. + let zerofier_poly_coefficients = zerofiers_evals + .iter() + .map(|evals| interpolate_cfft(evals.to_vec())) + .collect::>>>(); + + // Evaluate lde trace interpolating polynomial in trace domain. + for point in &domain.trace_coset_points { + println!( + "{:?}", + evaluate_point(&zerofier_poly_coefficients[0], &point) + ); + } + // --------------- END TESTING ---------------------------- + + let evaluations_t_iter = 0..domain.lde_coset_points.len(); + + let evaluations_t = evaluations_t_iter + .zip(boundary_evaluation) + .map(|(i, boundary)| { + let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); + + // Compute all the transition constraints at this point of the LDE domain. + let evaluations_transition = air.compute_transition_prover(&frame); + + // Add each term of the transition constraints to the composition polynomial, including the zerofier, + // the challenge and the exemption polynomial if it is necessary. + let acc_transition = itertools::izip!( + evaluations_transition, + &zerofiers_evals, + transition_coefficients + ) + .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { + acc + &zerof_eval[i] * eval * beta + }); + + acc_transition + boundary + }) + .collect(); + + evaluations_t + } +} diff --git a/provers/circle_stark/src/frame.rs b/provers/circle_stark/src/frame.rs index 27abc239b..315190f5f 100644 --- a/provers/circle_stark/src/frame.rs +++ b/provers/circle_stark/src/frame.rs @@ -24,7 +24,7 @@ impl Frame { let lde_steps = offsets .iter() - .map(|offset| lde_trace.get_row(row + offset).to_vec()) + .map(|offset| lde_trace.get_row((row + offset) % num_rows).to_vec()) .collect_vec(); Frame::new(lde_steps) diff --git a/provers/circle_stark/src/prover.rs b/provers/circle_stark/src/prover.rs index 482a92823..660bcf3aa 100644 --- a/provers/circle_stark/src/prover.rs +++ b/provers/circle_stark/src/prover.rs @@ -1,4 +1,4 @@ -use crate::{air::AIR, config::FriMerkleTree, domain::Domain, frame::Frame, trace::{LDETraceTable, TraceTable}}; +use crate::{air::AIR, config::FriMerkleTree, constraints::evaluator::ConstraintEvaluator, domain::Domain, frame::Frame, trace::{LDETraceTable, TraceTable}}; use std::marker::PhantomData; use super::config::Commitment; @@ -49,12 +49,12 @@ pub trait IsStarkProver { .collect::>>>(); // Evaluate lde trace interpolating polynomial in trace domain. - for point in domain.trace_coset_points { + for point in &domain.trace_coset_points { // println!("{:?}", evaluate_point(&lde_coefficients[0], &point)); } // Crate a LDE_TRACE with a blow up factor of one, so its the same values as the trace. - let lde_trace = LDETraceTable::new(trace.table.data.clone(), 1, 1); + let not_lde_trace = LDETraceTable::new(trace.table.data.clone(), 1, 1); // --------- VALIDATE BOUNDARY CONSTRAINTS ------------ air.boundary_constraints() @@ -75,11 +75,26 @@ pub trait IsStarkProver { }); // --------- VALIDATE TRANSITION CONSTRAINTS ----------- - for row_index in 0..lde_trace.table.height - 2 { - let frame = Frame::read_from_lde(&lde_trace, row_index, &air.context().transition_offsets); + for row_index in 0..not_lde_trace.table.height - 2 { + let frame = Frame::read_from_lde(¬_lde_trace, row_index, &air.context().transition_offsets); let evaluations = air.compute_transition_prover(&frame); - println!("Transition constraints evaluations: {:?}", evaluations); + // println!("Transition constraints evaluations: {:?}", evaluations); } + + + let transition_coefficients: Vec> = vec![FieldElement::::one(); air.context().num_transition_constraints()]; + let boundary_coefficients: Vec> = vec![FieldElement::::one(); air.boundary_constraints().constraints.len()]; + + // Compute the evaluations of the composition polynomial on the LDE domain. + let lde_trace = LDETraceTable::from_columns(lde_trace_evaluations, domain.blowup_factor); + let evaluator = ConstraintEvaluator::new(&air); + let constraint_evaluations = evaluator.evaluate( + &air, + &lde_trace, + &domain, + &transition_coefficients, + &boundary_coefficients, + ); } } From 6fb8311c1e0e2176570dd1f5e989a92551ac1292 Mon Sep 17 00:00:00 2001 From: Nicole Date: Fri, 8 Nov 2024 17:59:02 -0300 Subject: [PATCH 89/93] Add line and interpolant functions --- provers/circle_stark/src/air.rs | 1 + .../circle_stark/src/constraints/evaluator.rs | 48 ++++++++++++++++--- .../src/constraints/transition.rs | 42 ++++++++++++---- provers/circle_stark/src/prover.rs | 4 +- provers/circle_stark/src/tests/integration.rs | 2 +- 5 files changed, 79 insertions(+), 18 deletions(-) diff --git a/provers/circle_stark/src/air.rs b/provers/circle_stark/src/air.rs index 201194f4f..a59f9b9ba 100644 --- a/provers/circle_stark/src/air.rs +++ b/provers/circle_stark/src/air.rs @@ -87,6 +87,7 @@ pub trait AIR { .entry(zerofier_group_key) .or_insert_with(|| c.zerofier_evaluations_on_extended_domain(domain)); let zerofier_evaluations = zerofier_groups.get(&zerofier_group_key).unwrap(); + // println!("ZEROFIER_EVALUATIONS: {:?}", zerofier_evaluations); evals[c.constraint_idx()] = zerofier_evaluations.clone(); }); evals diff --git a/provers/circle_stark/src/constraints/evaluator.rs b/provers/circle_stark/src/constraints/evaluator.rs index b0b7f88cd..2f50d388c 100644 --- a/provers/circle_stark/src/constraints/evaluator.rs +++ b/provers/circle_stark/src/constraints/evaluator.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{domain::Domain, frame::Frame, trace::LDETraceTable}; use itertools::Itertools; +use lambdaworks_math::circle::point::CirclePoint; use lambdaworks_math::circle::polynomial::{evaluate_point, interpolate_cfft}; use lambdaworks_math::field::element::FieldElement; use lambdaworks_math::field::fields::mersenne31::field::Mersenne31Field; @@ -12,6 +13,38 @@ pub struct ConstraintEvaluator { boundary_constraints: BoundaryConstraints, phantom: PhantomData, } + +// See https://vitalik.eth.limo/general/2024/07/23/circlestarks.html (Section: Quetienting). +// https://github.com/ethereum/research/blob/master/circlestark/line_functions.py#L10 +pub fn line( + point: &CirclePoint, + vanish_point_1: &CirclePoint, + vanish_point_2: &CirclePoint, +) -> FieldElement { + (vanish_point_1.y - vanish_point_2.y) * point.x + + (vanish_point_2.x - vanish_point_1.x) * point.y + + (vanish_point_1.x * vanish_point_2.y - vanish_point_1.y * vanish_point_2.x) +} + +// See https://vitalik.eth.limo/general/2024/07/23/circlestarks.html (Section: Quetienting). +// https://github.com/ethereum/research/blob/master/circlestark/line_functions.py#L16 +// Evaluates the polybomial I at eval_point. I is the polynomial such that I(point_1) = value_1 and +// I(point_2) = value_2. +pub fn interpolant( + point_1: &CirclePoint, + point_2: &CirclePoint, + value_1: FieldElement, + value_2: FieldElement, + eval_point: &CirclePoint, +) -> FieldElement { + let dx = point_2.x - point_1.x; + let dy = point_2.y - point_1.y; + // CHECK: can dx^2 + dy^2 = 0 even if dx!=0 and dy!=0 ? (using that they are FE of Mersenne31). + let invdist = (dx * dx + dy * dy).inv().unwrap(); + let dot = (eval_point.x - point_1.x) * dx + (eval_point.y - point_1.y) * dy; + value_1 + (value_2 - value_1) * dot * invdist +} + impl ConstraintEvaluator { pub fn new(air: &A) -> Self { let boundary_constraints = air.boundary_constraints(); @@ -36,16 +69,16 @@ impl ConstraintEvaluator { let boundary_zerofiers_inverse_evaluations: Vec>> = boundary_constraints .constraints - .iter() - .map(|bc| { - let vanish_point = &domain.trace_coset_points[bc.step]; + .chunks(2) + .map(|chunk| { + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &domain.trace_coset_points[first_constraint.step]; + let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; let mut evals = domain .lde_coset_points .iter() - .map(|eval_point| { - (eval_point + vanish_point.clone().conjugate()).x - - FieldElement::::one() - }) + .map(|eval_point| line(eval_point, first_vanish_point, second_vanish_point)) .collect::>>(); FieldElement::inplace_batch_inverse(&mut evals).unwrap(); evals @@ -110,6 +143,7 @@ impl ConstraintEvaluator { .collect::>>>(); // Evaluate lde trace interpolating polynomial in trace domain. + // This should print all zeroes except in the end exceptions points. for point in &domain.trace_coset_points { println!( "{:?}", diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index 25d9da946..50004bfd0 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -1,12 +1,12 @@ use crate::domain::Domain; use crate::frame::Frame; use lambdaworks_math::circle::point::CirclePoint; +use lambdaworks_math::circle::polynomial::{evaluate_point, interpolate_cfft}; use lambdaworks_math::field::element::FieldElement; use lambdaworks_math::field::fields::mersenne31::field::Mersenne31Field; /// TransitionConstraint represents the behaviour that a transition constraint /// over the computation that wants to be proven must comply with. -pub trait TransitionConstraint -{ +pub trait TransitionConstraint { /// The degree of the constraint interpreting it as a multivariate polynomial. fn degree(&self) -> usize; @@ -82,12 +82,20 @@ pub trait TransitionConstraint return one; } let period = self.period(); + let double_group_generator = CirclePoint::::get_generator_of_subgroup( + trace_length.trailing_zeros() + 1, + ); // This accumulates evaluations of the point at the zerofier at all the offsets positions. (1..=self.end_exemptions()) - // FIXME: I think this is wrong because exemption should be and element of the stndard coset instead of the group. (-nicole) - .map(|exemption| trace_group_generator * ((trace_length - exemption * period) as u128)) + .map(|exemption| { + &double_group_generator + + (trace_group_generator * ((trace_length - exemption * period) as u128)) + }) .fold(one.clone(), |acc, vanishing_point| { - acc * ((eval_point + vanishing_point.conjugate()).x - &one) + // acc * ((eval_point + vanishing_point.conjugate()).x - &one) + + let h = eval_point + vanishing_point.conjugate(); + acc * (h.y / &one + h.x) }) } @@ -95,7 +103,10 @@ pub trait TransitionConstraint /// TODO: See if we can evaluate using cfft. /// TODO: See if we can optimize computing only some evaluations and cycle them as in regular stark. #[allow(unstable_name_collisions)] - fn zerofier_evaluations_on_extended_domain(&self, domain: &Domain) -> Vec> { + fn zerofier_evaluations_on_extended_domain( + &self, + domain: &Domain, + ) -> Vec> { let blowup_factor = domain.blowup_factor; let trace_length = domain.trace_length; let trace_log_2_size = trace_length.trailing_zeros(); @@ -119,14 +130,29 @@ pub trait TransitionConstraint x }) .collect(); - FieldElement::inplace_batch_inverse(&mut zerofier_evaluations).unwrap(); + // FieldElement::inplace_batch_inverse(&mut zerofier_evaluations).unwrap(); - let end_exemptions_evaluations: Vec<_> = lde_points + let mut end_exemptions_evaluations: Vec<_> = lde_points .iter() .map(|point| { self.evaluate_end_exemptions_poly(point, trace_group_generator, trace_length) }) .collect(); + FieldElement::inplace_batch_inverse(&mut end_exemptions_evaluations).unwrap(); + + // // --------------- BEGIN TESTING ---------------------------- + // // Interpolate lde trace evaluations. + // let end_exemptions_coeff = interpolate_cfft(end_exemptions_evaluations.clone()); + + // // Evaluate lde trace interpolating polynomial in trace domain. + // // This should print zeroes only in the end exceptions points. + // for point in &domain.trace_coset_points { + // println!( + // "EXEMPTIONS POLYS EVALUATED ON TRACE DOMAIN {:?}", + // evaluate_point(&end_exemptions_coeff, &point) + // ); + // } + // // --------------- END TESTING ---------------------------- std::iter::zip(zerofier_evaluations, end_exemptions_evaluations) .map(|(eval, exemptions_eval)| eval * exemptions_eval) diff --git a/provers/circle_stark/src/prover.rs b/provers/circle_stark/src/prover.rs index 660bcf3aa..3299685ae 100644 --- a/provers/circle_stark/src/prover.rs +++ b/provers/circle_stark/src/prover.rs @@ -68,9 +68,9 @@ pub trait IsStarkProver { let trace_value = trace.table.get(step, col).clone(); if boundary_value.clone() != trace_value { - println!("Boundary constraint inconsistency - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value); + // println!("Boundary constraint inconsistency - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value); } else { - println!("Consistent Boundary constraint - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value) + // println!("Consistent Boundary constraint - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value) } }); diff --git a/provers/circle_stark/src/tests/integration.rs b/provers/circle_stark/src/tests/integration.rs index d35bda95a..2da16ace1 100644 --- a/provers/circle_stark/src/tests/integration.rs +++ b/provers/circle_stark/src/tests/integration.rs @@ -16,7 +16,7 @@ use crate::{ fn test_prove_fib() { type FE = FieldElement; - let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 512); + let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 16); let pub_inputs = FibonacciPublicInputs { a0: FE::one(), From 27cadc358c9926a1f0966ed7fbe2c2e8ae1b19ca Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 11 Nov 2024 10:32:04 -0300 Subject: [PATCH 90/93] change boundary_polys_evaluations --- .../circle_stark/src/constraints/evaluator.rs | 66 ++++++++++++++----- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/provers/circle_stark/src/constraints/evaluator.rs b/provers/circle_stark/src/constraints/evaluator.rs index 2f50d388c..de0348171 100644 --- a/provers/circle_stark/src/constraints/evaluator.rs +++ b/provers/circle_stark/src/constraints/evaluator.rs @@ -17,13 +17,14 @@ pub struct ConstraintEvaluator { // See https://vitalik.eth.limo/general/2024/07/23/circlestarks.html (Section: Quetienting). // https://github.com/ethereum/research/blob/master/circlestark/line_functions.py#L10 pub fn line( - point: &CirclePoint, - vanish_point_1: &CirclePoint, - vanish_point_2: &CirclePoint, + eval_point: &CirclePoint, + first_vanish_point: &CirclePoint, + second_vanish_point: &CirclePoint, ) -> FieldElement { - (vanish_point_1.y - vanish_point_2.y) * point.x - + (vanish_point_2.x - vanish_point_1.x) * point.y - + (vanish_point_1.x * vanish_point_2.y - vanish_point_1.y * vanish_point_2.x) + (first_vanish_point.y - second_vanish_point.y) * eval_point.x + + (second_vanish_point.x - first_vanish_point.x) * eval_point.y + + (first_vanish_point.x * second_vanish_point.y + - first_vanish_point.y * second_vanish_point.x) } // See https://vitalik.eth.limo/general/2024/07/23/circlestarks.html (Section: Quetienting). @@ -85,18 +86,47 @@ impl ConstraintEvaluator { }) .collect::>>>(); - let boundary_polys_evaluations = boundary_constraints - .constraints - .iter() - .map(|constraint| { - (0..lde_trace.num_rows()) - .map(|row| { - let v = lde_trace.table.get(row, constraint.col); - v - &constraint.value - }) - .collect_vec() - }) - .collect_vec(); + let boundary_polys_evaluations: Vec>> = + boundary_constraints + .constraints + .chunks(2) + .map(|chunk| { + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &domain.trace_coset_points[first_constraint.step]; + let first_value = first_constraint.value; + let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; + let second_value = second_constraint.value; + let mut evals = domain + .lde_coset_points + .iter() + .map(|eval_point| { + interpolant( + first_vanish_point, + second_vanish_point, + first_value, + second_value, + eval_point, + ) + }) + .collect::>>(); + FieldElement::inplace_batch_inverse(&mut evals).unwrap(); + evals + }) + .collect::>>>(); + + // let boundary_polys_evaluations = boundary_constraints + // .constraints + // .iter() + // .map(|constraint| { + // (0..lde_trace.num_rows()) + // .map(|row| { + // let v = lde_trace.table.get(row, constraint.col); + // v - &constraint.value + // }) + // .collect_vec() + // }) + // .collect_vec(); // --------------- BEGIN TESTING ---------------------------- // Interpolate lde trace evaluations. From 4044f2e7cd464a6586df0916819c2528311e725d Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Mon, 11 Nov 2024 14:58:15 -0300 Subject: [PATCH 91/93] fix zerofier --- .../circle_stark/src/constraints/evaluator.rs | 98 ++++++++++++----- .../src/constraints/transition.rs | 102 +++++++++--------- .../src/examples/simple_fibonacci.rs | 2 +- provers/circle_stark/src/prover.rs | 2 +- provers/circle_stark/src/tests/integration.rs | 2 +- 5 files changed, 127 insertions(+), 79 deletions(-) diff --git a/provers/circle_stark/src/constraints/evaluator.rs b/provers/circle_stark/src/constraints/evaluator.rs index de0348171..956071461 100644 --- a/provers/circle_stark/src/constraints/evaluator.rs +++ b/provers/circle_stark/src/constraints/evaluator.rs @@ -97,11 +97,12 @@ impl ConstraintEvaluator { let first_value = first_constraint.value; let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; let second_value = second_constraint.value; - let mut evals = domain + let evals = domain .lde_coset_points .iter() - .map(|eval_point| { - interpolant( + .zip(&lde_trace.table.data) + .map(|(eval_point, lde_eval)| { + lde_eval - interpolant( first_vanish_point, second_vanish_point, first_value, @@ -110,7 +111,6 @@ impl ConstraintEvaluator { ) }) .collect::>>(); - FieldElement::inplace_batch_inverse(&mut evals).unwrap(); evals }) .collect::>>>(); @@ -130,18 +130,35 @@ impl ConstraintEvaluator { // --------------- BEGIN TESTING ---------------------------- // Interpolate lde trace evaluations. - // let boundary_poly_coefficients = boundary_polys_evaluations - // .iter() - // .map(|evals| interpolate_cfft(evals.to_vec())) - // .collect::>>>(); + let l_poly_coefficients = boundary_zerofiers_inverse_evaluations + .iter() + .map(|evals| { + let mut inverse_evals = evals.clone(); + FieldElement::inplace_batch_inverse(&mut inverse_evals).unwrap(); + interpolate_cfft(inverse_evals.to_vec()) + }) + .collect::>>>(); + + let fi_poly_coefficients = boundary_polys_evaluations + .iter() + .map(|evals| { + interpolate_cfft(evals.to_vec()) + }) + .collect::>>>(); // Evaluate lde trace interpolating polynomial in trace domain. - // for point in &domain.trace_coset_points { - // println!( - // "{:?}", - // evaluate_point(&boundary_poly_coefficients[0], &point) - // ); - // } + for point in &domain.trace_coset_points { + println!("-----------------------"); + println!( + "L evaluation: {:?}", + evaluate_point(&l_poly_coefficients[0], &point) + ); + println!( + "F-I evaluation: {:?}", + evaluate_point(&fi_poly_coefficients[0], &point) + ); + } + // --------------- END TESTING ---------------------------- let boundary_eval_iter = 0..domain.lde_coset_points.len(); @@ -149,6 +166,7 @@ impl ConstraintEvaluator { let boundary_evaluation: Vec<_> = boundary_eval_iter .map(|domain_index| { (0..number_of_b_constraints) + .step_by(2) .zip(boundary_coefficients) .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { acc + &boundary_zerofiers_inverse_evaluations[constraint_index] @@ -167,19 +185,19 @@ impl ConstraintEvaluator { // --------------- BEGIN TESTING ---------------------------- // Interpolate lde trace evaluations. - let zerofier_poly_coefficients = zerofiers_evals - .iter() - .map(|evals| interpolate_cfft(evals.to_vec())) - .collect::>>>(); + // let zerofier_poly_coefficients = zerofiers_evals + // .iter() + // .map(|evals| interpolate_cfft(evals.to_vec())) + // .collect::>>>(); - // Evaluate lde trace interpolating polynomial in trace domain. - // This should print all zeroes except in the end exceptions points. - for point in &domain.trace_coset_points { - println!( - "{:?}", - evaluate_point(&zerofier_poly_coefficients[0], &point) - ); - } + // // Evaluate lde trace interpolating polynomial in trace domain. + // // This should print all zeroes except in the end exceptions points. + // for point in &domain.trace_coset_points { + // println!( + // "{:?}", + // evaluate_point(&zerofier_poly_coefficients[0], &point) + // ); + // } // --------------- END TESTING ---------------------------- let evaluations_t_iter = 0..domain.lde_coset_points.len(); @@ -210,3 +228,31 @@ impl ConstraintEvaluator { evaluations_t } } + +#[cfg(test)] +mod tests { + use super::*; + use lambdaworks_math::circle::cosets::Coset; + + type FE = FieldElement; + + #[test] + fn line_vanishes_in_vanishing_points(){ + let first_vanish_point = CirclePoint::GENERATOR * 3; + let second_vanish_point = CirclePoint::GENERATOR * 5; + assert_eq!(line(&first_vanish_point, &first_vanish_point, &second_vanish_point), FE::zero()); + assert_eq!(line(&second_vanish_point, &first_vanish_point, &second_vanish_point), FE::zero()); + } + + #[test] + fn interpolant_takes_the_corresponding_values(){ + let point_1 = CirclePoint::GENERATOR * 4; + let point_2 = CirclePoint::GENERATOR * 9; + let value_1 = FE::from(5); + let value_2 = FE::from(3); + assert_eq!(interpolant(&point_1, &point_2, value_1, value_2, &point_1), value_1); + assert_eq!(interpolant(&point_1, &point_2, value_1, value_2, &point_2), value_2); + + } + +} diff --git a/provers/circle_stark/src/constraints/transition.rs b/provers/circle_stark/src/constraints/transition.rs index 50004bfd0..08ba957ff 100644 --- a/provers/circle_stark/src/constraints/transition.rs +++ b/provers/circle_stark/src/constraints/transition.rs @@ -1,5 +1,6 @@ use crate::domain::Domain; use crate::frame::Frame; +use crate::constraints::evaluator::line; use lambdaworks_math::circle::point::CirclePoint; use lambdaworks_math::circle::polynomial::{evaluate_point, interpolate_cfft}; use lambdaworks_math::field::element::FieldElement; @@ -77,26 +78,28 @@ pub trait TransitionConstraint { trace_group_generator: &CirclePoint, trace_length: usize, ) -> FieldElement { + let one = FieldElement::::one(); + if self.end_exemptions() == 0 { return one; } - let period = self.period(); + let double_group_generator = CirclePoint::::get_generator_of_subgroup( trace_length.trailing_zeros() + 1, ); - // This accumulates evaluations of the point at the zerofier at all the offsets positions. + (1..=self.end_exemptions()) + .step_by(2) .map(|exemption| { - &double_group_generator - + (trace_group_generator * ((trace_length - exemption * period) as u128)) - }) - .fold(one.clone(), |acc, vanishing_point| { - // acc * ((eval_point + vanishing_point.conjugate()).x - &one) + println!("EXEMPTION: {:?}", exemption); + let first_vanish_point = &double_group_generator + (trace_group_generator * ((trace_length - exemption) as u128)); + + let second_vanish_point = &double_group_generator + (trace_group_generator * ((trace_length - (exemption + 1)) as u128)); - let h = eval_point + vanishing_point.conjugate(); - acc * (h.y / &one + h.x) + line(eval_point, &first_vanish_point, &second_vanish_point) }) + .fold(one, |acc, eval| acc * eval) } /// Compute evaluations of the constraints zerofier over a LDE domain. @@ -130,57 +133,56 @@ pub trait TransitionConstraint { x }) .collect(); - // FieldElement::inplace_batch_inverse(&mut zerofier_evaluations).unwrap(); + FieldElement::inplace_batch_inverse(&mut zerofier_evaluations).unwrap(); - let mut end_exemptions_evaluations: Vec<_> = lde_points + let end_exemptions_evaluations: Vec<_> = lde_points .iter() .map(|point| { self.evaluate_end_exemptions_poly(point, trace_group_generator, trace_length) }) .collect(); - FieldElement::inplace_batch_inverse(&mut end_exemptions_evaluations).unwrap(); - - // // --------------- BEGIN TESTING ---------------------------- - // // Interpolate lde trace evaluations. - // let end_exemptions_coeff = interpolate_cfft(end_exemptions_evaluations.clone()); - - // // Evaluate lde trace interpolating polynomial in trace domain. - // // This should print zeroes only in the end exceptions points. - // for point in &domain.trace_coset_points { - // println!( - // "EXEMPTIONS POLYS EVALUATED ON TRACE DOMAIN {:?}", - // evaluate_point(&end_exemptions_coeff, &point) - // ); - // } - // // --------------- END TESTING ---------------------------- + + // --------------- BEGIN TESTING ---------------------------- + // Interpolate lde trace evaluations. + let end_exemptions_coeff = interpolate_cfft(end_exemptions_evaluations.clone()); + + // Evaluate lde trace interpolating polynomial in trace domain. + // This should print zeroes only in the end exceptions points. + for point in &domain.trace_coset_points { + println!( + "EXEMPTIONS POLYS EVALUATED ON TRACE DOMAIN {:?}", + evaluate_point(&end_exemptions_coeff, &point) + ); + } + // --------------- END TESTING ---------------------------- std::iter::zip(zerofier_evaluations, end_exemptions_evaluations) .map(|(eval, exemptions_eval)| eval * exemptions_eval) .collect() } - /// Returns the evaluation of the zerofier corresponding to this constraint in some point - /// `eval_point`, (which is in the circle over the extension field). - #[allow(unstable_name_collisions)] - fn evaluate_zerofier( - &self, - eval_point: &CirclePoint, - trace_group_generator: &CirclePoint, - trace_length: usize, - ) -> FieldElement { - // if let Some(exemptions_period) = self.exemptions_period() { - - // } else { - - let end_exemptions_evaluation = - self.evaluate_end_exemptions_poly(eval_point, trace_group_generator, trace_length); - - let trace_log_2_size = trace_length.trailing_zeros(); - let mut x = eval_point.x.clone(); - for _ in 1..trace_log_2_size { - x = x.square().double() - FieldElement::::one(); - } - - x.inv().unwrap() * end_exemptions_evaluation - } + ///// Returns the evaluation of the zerofier corresponding to this constraint in some point + ///// `eval_point`, (which is in the circle over the extension field). + // #[allow(unstable_name_collisions)] + // fn evaluate_zerofier( + // &self, + // eval_point: &CirclePoint, + // trace_group_generator: &CirclePoint, + // trace_length: usize, + // ) -> FieldElement { + // // if let Some(exemptions_period) = self.exemptions_period() { + + // // } else { + + // let end_exemptions_evaluation = + // self.evaluate_end_exemptions_poly(eval_point, trace_group_generator, trace_length); + + // let trace_log_2_size = trace_length.trailing_zeros(); + // let mut x = eval_point.x.clone(); + // for _ in 1..trace_log_2_size { + // x = x.square().double() - FieldElement::::one(); + // } + + // x.inv().unwrap() * end_exemptions_evaluation + // } } diff --git a/provers/circle_stark/src/examples/simple_fibonacci.rs b/provers/circle_stark/src/examples/simple_fibonacci.rs index ae2d35fde..927bf13a4 100644 --- a/provers/circle_stark/src/examples/simple_fibonacci.rs +++ b/provers/circle_stark/src/examples/simple_fibonacci.rs @@ -26,7 +26,7 @@ impl TransitionConstraint for FibConstraint { 0 } fn end_exemptions(&self) -> usize { - 2 + 4 } fn evaluate( &self, diff --git a/provers/circle_stark/src/prover.rs b/provers/circle_stark/src/prover.rs index 3299685ae..6032aaab2 100644 --- a/provers/circle_stark/src/prover.rs +++ b/provers/circle_stark/src/prover.rs @@ -83,7 +83,7 @@ pub trait IsStarkProver { let transition_coefficients: Vec> = vec![FieldElement::::one(); air.context().num_transition_constraints()]; - let boundary_coefficients: Vec> = vec![FieldElement::::one(); air.boundary_constraints().constraints.len()]; + let boundary_coefficients: Vec> = vec![FieldElement::::one(); air.boundary_constraints().constraints.len() / 2]; // Compute the evaluations of the composition polynomial on the LDE domain. let lde_trace = LDETraceTable::from_columns(lde_trace_evaluations, domain.blowup_factor); diff --git a/provers/circle_stark/src/tests/integration.rs b/provers/circle_stark/src/tests/integration.rs index 2da16ace1..9fe78a89b 100644 --- a/provers/circle_stark/src/tests/integration.rs +++ b/provers/circle_stark/src/tests/integration.rs @@ -16,7 +16,7 @@ use crate::{ fn test_prove_fib() { type FE = FieldElement; - let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 16); + let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 32); let pub_inputs = FibonacciPublicInputs { a0: FE::one(), From 4cb5b162bf9a1a79c91d4fc9831582318e62541e Mon Sep 17 00:00:00 2001 From: Nicole Date: Mon, 11 Nov 2024 18:55:02 -0300 Subject: [PATCH 92/93] add utils to constraints. Add composition_polynomial.rs --- .../src/composition_polynomial.rs | 180 ++++++++++++++++++ .../circle_stark/src/constraints/boundary.rs | 70 +++---- provers/circle_stark/src/constraints/mod.rs | 1 + provers/circle_stark/src/constraints/utils.rs | 36 ++++ provers/circle_stark/src/lib.rs | 7 +- 5 files changed, 257 insertions(+), 37 deletions(-) create mode 100644 provers/circle_stark/src/composition_polynomial.rs create mode 100644 provers/circle_stark/src/constraints/utils.rs diff --git a/provers/circle_stark/src/composition_polynomial.rs b/provers/circle_stark/src/composition_polynomial.rs new file mode 100644 index 000000000..5586f23b2 --- /dev/null +++ b/provers/circle_stark/src/composition_polynomial.rs @@ -0,0 +1,180 @@ +use crate::air::AIR; +use crate::{ + constraints::boundary::BoundaryConstraints, domain::Domain, frame::Frame, trace::LDETraceTable, +}; +use itertools::Itertools; +use lambdaworks_math::{ + circle::{ + point::CirclePoint, + polynomial::{evaluate_point, interpolate_cfft}, + }, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; + +pub(crate) fn evaluate_cp( + air: &A, + lde_trace: &LDETraceTable, + domain: &Domain, + transition_coefficients: &[FieldElement], + boundary_coefficients: &[FieldElement], +) -> Vec> { + let boundary_constraints = &air.boundary_constraints(); + let number_of_b_constraints = boundary_constraints.constraints.len(); + + let boundary_zerofiers_inverse_evaluations: Vec>> = + boundary_constraints + .constraints + .chunks(2) + .map(|chunk| { + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &domain.trace_coset_points[first_constraint.step]; + let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; + let mut evals = domain + .lde_coset_points + .iter() + .map(|eval_point| line(eval_point, first_vanish_point, second_vanish_point)) + .collect::>>(); + FieldElement::inplace_batch_inverse(&mut evals).unwrap(); + evals + }) + .collect::>>>(); + + let boundary_polys_evaluations: Vec>> = boundary_constraints + .constraints + .chunks(2) + .map(|chunk| { + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &domain.trace_coset_points[first_constraint.step]; + let first_value = first_constraint.value; + let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; + let second_value = second_constraint.value; + let evals = domain + .lde_coset_points + .iter() + .zip(&lde_trace.table.data) + .map(|(eval_point, lde_eval)| { + lde_eval + - interpolant( + first_vanish_point, + second_vanish_point, + first_value, + second_value, + eval_point, + ) + }) + .collect::>>(); + evals + }) + .collect::>>>(); + + // let boundary_polys_evaluations = boundary_constraints + // .constraints + // .iter() + // .map(|constraint| { + // (0..lde_trace.num_rows()) + // .map(|row| { + // let v = lde_trace.table.get(row, constraint.col); + // v - &constraint.value + // }) + // .collect_vec() + // }) + // .collect_vec(); + + // --------------- BEGIN TESTING ---------------------------- + // Interpolate lde trace evaluations. + let l_poly_coefficients = boundary_zerofiers_inverse_evaluations + .iter() + .map(|evals| { + let mut inverse_evals = evals.clone(); + FieldElement::inplace_batch_inverse(&mut inverse_evals).unwrap(); + interpolate_cfft(inverse_evals.to_vec()) + }) + .collect::>>>(); + + let fi_poly_coefficients = boundary_polys_evaluations + .iter() + .map(|evals| interpolate_cfft(evals.to_vec())) + .collect::>>>(); + + // Evaluate lde trace interpolating polynomial in trace domain. + for point in &domain.trace_coset_points { + println!("-----------------------"); + println!( + "L evaluation: {:?}", + evaluate_point(&l_poly_coefficients[0], &point) + ); + println!( + "F-I evaluation: {:?}", + evaluate_point(&fi_poly_coefficients[0], &point) + ); + } + + // --------------- END TESTING ---------------------------- + + let boundary_eval_iter = 0..domain.lde_coset_points.len(); + + let boundary_evaluation: Vec<_> = boundary_eval_iter + .map(|domain_index| { + (0..number_of_b_constraints) + .step_by(2) + .zip(boundary_coefficients) + .fold(FieldElement::zero(), |acc, (constraint_index, beta)| { + acc + &boundary_zerofiers_inverse_evaluations[constraint_index][domain_index] + * beta + * &boundary_polys_evaluations[constraint_index][domain_index] + }) + }) + .collect(); + + // Iterate over all LDE domain and compute the part of the composition polynomial + // related to the transition constraints and add it to the already computed part of the + // boundary constraints. + + let zerofiers_evals = air.transition_zerofier_evaluations(domain); + + // --------------- BEGIN TESTING ---------------------------- + // Interpolate lde trace evaluations. + // let zerofier_poly_coefficients = zerofiers_evals + // .iter() + // .map(|evals| interpolate_cfft(evals.to_vec())) + // .collect::>>>(); + + // // Evaluate lde trace interpolating polynomial in trace domain. + // // This should print all zeroes except in the end exceptions points. + // for point in &domain.trace_coset_points { + // println!( + // "{:?}", + // evaluate_point(&zerofier_poly_coefficients[0], &point) + // ); + // } + // --------------- END TESTING ---------------------------- + + let evaluations_t_iter = 0..domain.lde_coset_points.len(); + + let evaluations_t = evaluations_t_iter + .zip(boundary_evaluation) + .map(|(i, boundary)| { + let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); + + // Compute all the transition constraints at this point of the LDE domain. + let evaluations_transition = air.compute_transition_prover(&frame); + + // Add each term of the transition constraints to the composition polynomial, including the zerofier, + // the challenge and the exemption polynomial if it is necessary. + let acc_transition = itertools::izip!( + evaluations_transition, + &zerofiers_evals, + transition_coefficients + ) + .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { + acc + &zerof_eval[i] * eval * beta + }); + + acc_transition + boundary + }) + .collect(); + + evaluations_t +} diff --git a/provers/circle_stark/src/constraints/boundary.rs b/provers/circle_stark/src/constraints/boundary.rs index c64a84c5f..5791a7b0d 100644 --- a/provers/circle_stark/src/constraints/boundary.rs +++ b/provers/circle_stark/src/constraints/boundary.rs @@ -5,6 +5,8 @@ use lambdaworks_math::{ polynomial::Polynomial, }; +use super::evaluator::line; + #[derive(Debug)] /// Represents a boundary constraint that must hold in an execution /// trace: @@ -113,30 +115,33 @@ impl BoundaryConstraints { .collect() } - /// Evaluate the zerofier of the boundary constraints for a column. The result is the - /// multiplication of each zerofier that evaluates to zero in the domain - /// values where the boundary constraints must hold. - /// - /// Example: If there are boundary conditions in the third and fifth steps, - /// then the zerofier will be f(x, y) = ( ((x, y) + p3.conjugate()).x - 1 ) * ( ((x, y) + p5.conjugate()).x - 1 ) - /// (eval_point + vanish_point.conjugate()).x - FieldElement::::one() - /// TODO: Optimize this function so we don't need to look up and indexes in the coset vector and clone its value. + /// Given a column, it returns for each boundary constraint in that column, the corresponding evaluation + /// in all the `eval_points`. + /// We assume that there are an even number of boundary contrainsts in the column `col`. pub fn evaluate_zerofier( &self, trace_coset: &Vec>, col: usize, - eval_point: &CirclePoint, - ) -> FieldElement { - self.steps(col).into_iter().fold( - FieldElement::::one(), - |zerofier, step| { - let vanish_point = trace_coset[step].clone(); - let evaluation = (eval_point + vanish_point.conjugate()).x - - FieldElement::::one(); - // TODO: Implement the MulAssign trait for Polynomials? - zerofier * evaluation - }, - ) + eval_points: &Vec>, + ) -> Vec>> { + self.constraints + .iter() + .filter(|constraint| constraint.col == col) + .chunks(2) + .into_iter() + .map(|chunk| { + let chunk: Vec<_> = chunk.collect(); + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &trace_coset[first_constraint.step]; + let second_vanish_point = &trace_coset[second_constraint.step]; + + eval_points + .iter() + .map(|eval_point| line(eval_point, &first_vanish_point, &second_vanish_point)) + .collect() + }) + .collect() } } @@ -153,26 +158,25 @@ mod test { use super::*; #[test] - fn zerofier_is_the_correct_one() { + fn simple_fibonacci_boundary_zerofiers() { let one = FieldElement::::one(); // Fibonacci constraints: - // * a0 = 1 - // * a1 = 1 - // * a7 = 32 + // a0 = 1 + // a1 = 1 let a0 = BoundaryConstraint::new_simple(0, one); let a1 = BoundaryConstraint::new_simple(1, one); - let result = BoundaryConstraint::new_simple(7, FieldElement::::from(32)); let trace_coset = Coset::get_coset_points(&Coset::new_standard(3)); - let eval_point = CirclePoint::::GENERATOR * 2; - let a0_zerofier = (&eval_point + &trace_coset[0].clone().conjugate()).x - one; - let a1_zerofier = (&eval_point + &trace_coset[1].clone().conjugate()).x - one; - let res_zerofier = (&eval_point + &trace_coset[7].clone().conjugate()).x - one; - let expected_zerofier = a0_zerofier * a1_zerofier * res_zerofier; - - let constraints = BoundaryConstraints::from_constraints(vec![a0, a1, result]); - let zerofier = constraints.evaluate_zerofier(&trace_coset, 0, &eval_point); + let eval_points = Coset::get_coset_points(&Coset::new_standard(4)); + + let expected_zerofier: Vec>> = vec![eval_points + .iter() + .map(|point| line(point, &trace_coset[0], &trace_coset[1])) + .collect()]; + + let constraints = BoundaryConstraints::from_constraints(vec![a0, a1]); + let zerofier = constraints.evaluate_zerofier(&trace_coset, 0, &eval_points); assert_eq!(expected_zerofier, zerofier); } diff --git a/provers/circle_stark/src/constraints/mod.rs b/provers/circle_stark/src/constraints/mod.rs index 3811523b5..f6f474a44 100644 --- a/provers/circle_stark/src/constraints/mod.rs +++ b/provers/circle_stark/src/constraints/mod.rs @@ -1,3 +1,4 @@ pub mod boundary; pub mod evaluator; pub mod transition; +pub mod utils; diff --git a/provers/circle_stark/src/constraints/utils.rs b/provers/circle_stark/src/constraints/utils.rs new file mode 100644 index 000000000..b92ddde6b --- /dev/null +++ b/provers/circle_stark/src/constraints/utils.rs @@ -0,0 +1,36 @@ +use lambdaworks_math::{ + circle::point::CirclePoint, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; + +// See https://vitalik.eth.limo/general/2024/07/23/circlestarks.html (Section: Quetienting). +// https://github.com/ethereum/research/blob/master/circlestark/line_functions.py#L10 +pub fn line( + eval_point: &CirclePoint, + first_vanish_point: &CirclePoint, + second_vanish_point: &CirclePoint, +) -> FieldElement { + (first_vanish_point.y - second_vanish_point.y) * eval_point.x + + (second_vanish_point.x - first_vanish_point.x) * eval_point.y + + (first_vanish_point.x * second_vanish_point.y + - first_vanish_point.y * second_vanish_point.x) +} + +// See https://vitalik.eth.limo/general/2024/07/23/circlestarks.html (Section: Quetienting). +// https://github.com/ethereum/research/blob/master/circlestark/line_functions.py#L16 +// Evaluates the polybomial I at eval_point. I is the polynomial such that I(point_1) = value_1 and +// I(point_2) = value_2. +pub fn interpolant( + point_1: &CirclePoint, + point_2: &CirclePoint, + value_1: FieldElement, + value_2: FieldElement, + eval_point: &CirclePoint, +) -> FieldElement { + let dx = point_2.x - point_1.x; + let dy = point_2.y - point_1.y; + // CHECK: can dx^2 + dy^2 = 0 even if dx!=0 and dy!=0 ? (using that they are FE of Mersenne31). + let invdist = (dx * dx + dy * dy).inv().unwrap(); + let dot = (eval_point.x - point_1.x) * dx + (eval_point.y - point_1.y) * dy; + value_1 + (value_2 - value_1) * dot * invdist +} diff --git a/provers/circle_stark/src/lib.rs b/provers/circle_stark/src/lib.rs index f4108cf40..999534a5e 100644 --- a/provers/circle_stark/src/lib.rs +++ b/provers/circle_stark/src/lib.rs @@ -1,16 +1,15 @@ pub mod air; pub mod air_context; +pub mod composition_polynomial; pub mod config; pub mod constraints; pub mod domain; +pub mod examples; pub mod frame; pub mod prover; pub mod table; pub mod trace; pub mod vanishing_poly; -pub mod examples; - - #[cfg(test)] -pub mod tests; \ No newline at end of file +pub mod tests; From 664fafa3069589656f7a24e08a9eea3cff425545 Mon Sep 17 00:00:00 2001 From: Nicole Date: Wed, 13 Nov 2024 10:41:42 -0300 Subject: [PATCH 93/93] Add composition polynomial function. Change boundary constraints funtions --- .../src/composition_polynomial.rs | 280 +++++++++--------- .../circle_stark/src/constraints/boundary.rs | 147 +++++++-- 2 files changed, 265 insertions(+), 162 deletions(-) diff --git a/provers/circle_stark/src/composition_polynomial.rs b/provers/circle_stark/src/composition_polynomial.rs index 5586f23b2..2704ee336 100644 --- a/provers/circle_stark/src/composition_polynomial.rs +++ b/provers/circle_stark/src/composition_polynomial.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use lambdaworks_math::{ circle::{ point::CirclePoint, - polynomial::{evaluate_point, interpolate_cfft}, + polynomial::{evaluate_cfft, evaluate_point, interpolate_cfft}, }, field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, }; @@ -18,104 +18,23 @@ pub(crate) fn evaluate_cp( transition_coefficients: &[FieldElement], boundary_coefficients: &[FieldElement], ) -> Vec> { + // >>> First, we compute the part of the Composition Polynomial related to the boundary constraints. + let boundary_constraints = &air.boundary_constraints(); let number_of_b_constraints = boundary_constraints.constraints.len(); + let trace_coset = &domain.trace_coset_points; + let lde_coset = &domain.lde_coset_points; - let boundary_zerofiers_inverse_evaluations: Vec>> = - boundary_constraints - .constraints - .chunks(2) - .map(|chunk| { - let first_constraint = &chunk[0]; - let second_constraint = &chunk[1]; - let first_vanish_point = &domain.trace_coset_points[first_constraint.step]; - let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; - let mut evals = domain - .lde_coset_points - .iter() - .map(|eval_point| line(eval_point, first_vanish_point, second_vanish_point)) - .collect::>>(); - FieldElement::inplace_batch_inverse(&mut evals).unwrap(); - evals - }) - .collect::>>>(); - - let boundary_polys_evaluations: Vec>> = boundary_constraints - .constraints - .chunks(2) - .map(|chunk| { - let first_constraint = &chunk[0]; - let second_constraint = &chunk[1]; - let first_vanish_point = &domain.trace_coset_points[first_constraint.step]; - let first_value = first_constraint.value; - let second_vanish_point = &domain.trace_coset_points[second_constraint.step]; - let second_value = second_constraint.value; - let evals = domain - .lde_coset_points - .iter() - .zip(&lde_trace.table.data) - .map(|(eval_point, lde_eval)| { - lde_eval - - interpolant( - first_vanish_point, - second_vanish_point, - first_value, - second_value, - eval_point, - ) - }) - .collect::>>(); - evals - }) - .collect::>>>(); - - // let boundary_polys_evaluations = boundary_constraints - // .constraints - // .iter() - // .map(|constraint| { - // (0..lde_trace.num_rows()) - // .map(|row| { - // let v = lde_trace.table.get(row, constraint.col); - // v - &constraint.value - // }) - // .collect_vec() - // }) - // .collect_vec(); - - // --------------- BEGIN TESTING ---------------------------- - // Interpolate lde trace evaluations. - let l_poly_coefficients = boundary_zerofiers_inverse_evaluations - .iter() - .map(|evals| { - let mut inverse_evals = evals.clone(); - FieldElement::inplace_batch_inverse(&mut inverse_evals).unwrap(); - interpolate_cfft(inverse_evals.to_vec()) - }) - .collect::>>>(); - - let fi_poly_coefficients = boundary_polys_evaluations - .iter() - .map(|evals| interpolate_cfft(evals.to_vec())) - .collect::>>>(); - - // Evaluate lde trace interpolating polynomial in trace domain. - for point in &domain.trace_coset_points { - println!("-----------------------"); - println!( - "L evaluation: {:?}", - evaluate_point(&l_poly_coefficients[0], &point) - ); - println!( - "F-I evaluation: {:?}", - evaluate_point(&fi_poly_coefficients[0], &point) - ); - } - - // --------------- END TESTING ---------------------------- + // For each pair of boundary constraints, we calculate the denominator's evaluations. + let boundary_zerofiers_inverse_evaluations = + boundary_constraints.evaluate_zerofiers(&trace_coset, &lde_coset); - let boundary_eval_iter = 0..domain.lde_coset_points.len(); + // For each pair of boundary constraints, we calculate the numerator's evaluations. + let boundary_polys_evaluations = + boundary_constraints.evaluate_poly_constraints(&trace_coset, &lde_coset, lde_trace); - let boundary_evaluation: Vec<_> = boundary_eval_iter + // We begin to construct the cp by adding each numerator mulpitlied by the denominator and the beta coefficient. + let cp_boundary: Vec> = (0..lde_coset.len()) .map(|domain_index| { (0..number_of_b_constraints) .step_by(2) @@ -128,53 +47,138 @@ pub(crate) fn evaluate_cp( }) .collect(); - // Iterate over all LDE domain and compute the part of the composition polynomial - // related to the transition constraints and add it to the already computed part of the - // boundary constraints. - - let zerofiers_evals = air.transition_zerofier_evaluations(domain); - - // --------------- BEGIN TESTING ---------------------------- - // Interpolate lde trace evaluations. - // let zerofier_poly_coefficients = zerofiers_evals - // .iter() - // .map(|evals| interpolate_cfft(evals.to_vec())) - // .collect::>>>(); - - // // Evaluate lde trace interpolating polynomial in trace domain. - // // This should print all zeroes except in the end exceptions points. - // for point in &domain.trace_coset_points { - // println!( - // "{:?}", - // evaluate_point(&zerofier_poly_coefficients[0], &point) - // ); - // } - // --------------- END TESTING ---------------------------- - - let evaluations_t_iter = 0..domain.lde_coset_points.len(); - - let evaluations_t = evaluations_t_iter - .zip(boundary_evaluation) - .map(|(i, boundary)| { - let frame = Frame::read_from_lde(lde_trace, i, &air.context().transition_offsets); - - // Compute all the transition constraints at this point of the LDE domain. - let evaluations_transition = air.compute_transition_prover(&frame); - - // Add each term of the transition constraints to the composition polynomial, including the zerofier, - // the challenge and the exemption polynomial if it is necessary. - let acc_transition = itertools::izip!( - evaluations_transition, - &zerofiers_evals, + // >>> Now we compute the part of the CP related to the transition constraints and add it to the already + // computed part of the boundary constraints. + + // For each transition constraint, we calulate its zerofier's evaluations. + let transition_zerofiers_inverse_evaluations = air.transition_zerofier_evaluations(domain); + + // + let cp_evaluations = (0..lde_coset.len()) + .zip(cp_boundary) + .map(|(eval_index, boundary_eval)| { + let frame = + Frame::read_from_lde(lde_trace, eval_index, &air.context().transition_offsets); + let transition_poly_evaluations = air.compute_transition_prover(&frame); + let transition_polys_accumulator = itertools::izip!( + transition_poly_evaluations, + &transition_zerofiers_inverse_evaluations, transition_coefficients ) - .fold(FieldElement::zero(), |acc, (eval, zerof_eval, beta)| { - acc + &zerof_eval[i] * eval * beta - }); - - acc_transition + boundary + .fold( + FieldElement::zero(), + |acc, (transition_eval, zerof_eval, beta)| { + acc + &zerof_eval[eval_index] * transition_eval * beta + }, + ); + transition_polys_accumulator + boundary_eval }) .collect(); - evaluations_t + cp_evaluations +} + +#[cfg(test)] +mod test { + + use crate::examples::simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs}; + + use super::*; + + type FE = FieldElement; + + fn build_fibonacci_example() {} + + #[test] + fn boundary_zerofiers_vanish_correctly() { + // Build Fibonacci Example + let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 32); + let pub_inputs = FibonacciPublicInputs { + a0: FE::one(), + a1: FE::one(), + }; + let air = FibonacciAIR::new(trace.n_rows(), &pub_inputs); + let boundary_constraints = air.boundary_constraints(); + let domain = Domain::new(&air); + + // Calculate the boundary zerofiers evaluations (L function). + let boundary_zerofiers_inverse_evaluations = boundary_constraints + .evaluate_zerofiers(&domain.trace_coset_points, &domain.lde_coset_points); + + // Interpolate the boundary zerofiers evaluations. + let boundary_zerofiers_coeff = boundary_zerofiers_inverse_evaluations + .iter() + .map(|evals| { + let mut inverse_evals = evals.clone(); + FieldElement::inplace_batch_inverse(&mut inverse_evals).unwrap(); + interpolate_cfft(inverse_evals.to_vec()) + }) + .collect::>>>(); + + // Since simple fibonacci only has one pair of boundary constraints we only check that + // the corresponding polynomial evaluates 0 in the first two coset points and different from 0 + // in the rest of the points. + assert_eq!( + evaluate_point(&boundary_zerofiers_coeff[0], &domain.trace_coset_points[0]), + FE::zero() + ); + + assert_eq!( + evaluate_point(&boundary_zerofiers_coeff[0], &domain.trace_coset_points[1]), + FE::zero() + ); + + for point in domain.trace_coset_points.iter().skip(2) { + assert_ne!( + evaluate_point(&boundary_zerofiers_coeff[0], &point), + FE::zero() + ); + } + } + + #[test] + fn boundary_polys_vanish_correctly() { + // Build Fibonacci Example + let trace = simple_fibonacci::fibonacci_trace([FE::one(), FE::one()], 32); + let pub_inputs = FibonacciPublicInputs { + a0: FE::one(), + a1: FE::one(), + }; + let air = FibonacciAIR::new(trace.n_rows(), &pub_inputs); + let boundary_constraints = air.boundary_constraints(); + let domain = Domain::new(&air); + + // Evaluate each polynomial in the lde domain. + let lde_trace = LDETraceTable::new(trace.table.data.clone(), 1, 1); + + // Calculate boundary polynomials evaluations (the polynomial f - I). + let boundary_polys_evaluations = boundary_constraints.evaluate_poly_constraints( + &domain.trace_coset_points, + &domain.lde_coset_points, + &lde_trace, + ); + + // Interpolate the boundary polynomials evaluations. + let boundary_poly_coeff = boundary_polys_evaluations + .iter() + .map(|evals| interpolate_cfft(evals.to_vec())) + .collect::>>>(); + + // Since simple fibonacci only has one pair of boundary constraints we only check that + // the corresponding polynomial evaluates 0 in the first two coset points and different from 0 + // in the rest of the points. + assert_eq!( + evaluate_point(&boundary_poly_coeff[0], &domain.trace_coset_points[0]), + FE::zero() + ); + + assert_eq!( + evaluate_point(&boundary_poly_coeff[0], &domain.trace_coset_points[1]), + FE::zero() + ); + + for point in domain.trace_coset_points.iter().skip(2) { + assert_ne!(evaluate_point(&boundary_poly_coeff[0], &point), FE::zero()); + } + } } diff --git a/provers/circle_stark/src/constraints/boundary.rs b/provers/circle_stark/src/constraints/boundary.rs index 5791a7b0d..dab5a92e1 100644 --- a/provers/circle_stark/src/constraints/boundary.rs +++ b/provers/circle_stark/src/constraints/boundary.rs @@ -5,7 +5,9 @@ use lambdaworks_math::{ polynomial::Polynomial, }; -use super::evaluator::line; +use crate::trace::LDETraceTable; + +use super::{evaluator::line, utils::interpolant}; #[derive(Debug)] /// Represents a boundary constraint that must hold in an execution @@ -115,34 +117,131 @@ impl BoundaryConstraints { .collect() } - /// Given a column, it returns for each boundary constraint in that column, the corresponding evaluation + /// For every column, it returns for each pair of boundary constraints in that column, the corresponding evaluation /// in all the `eval_points`. - /// We assume that there are an even number of boundary contrainsts in the column `col`. - pub fn evaluate_zerofier( + // We assume that there are an even number of boundary contrainsts in the column `col`. + // TODO: If two columns have a pair of boundary constraints that hold in the same rows, we can optimize + // this function so that we don't calculate the same line evaluations for the two of them. Maybe a hashmap? + pub fn evaluate_zerofiers( &self, trace_coset: &Vec>, - col: usize, eval_points: &Vec>, ) -> Vec>> { - self.constraints - .iter() - .filter(|constraint| constraint.col == col) - .chunks(2) - .into_iter() - .map(|chunk| { - let chunk: Vec<_> = chunk.collect(); - let first_constraint = &chunk[0]; - let second_constraint = &chunk[1]; - let first_vanish_point = &trace_coset[first_constraint.step]; - let second_vanish_point = &trace_coset[second_constraint.step]; - - eval_points - .iter() - .map(|eval_point| line(eval_point, &first_vanish_point, &second_vanish_point)) - .collect() - }) - .collect() + let mut zerofiers_evaluations = Vec::new(); + for col in self.cols_for_boundary() { + self.constraints + .iter() + .filter(|constraint| constraint.col == col) + .chunks(2) + .into_iter() + .map(|chunk| { + let chunk: Vec<_> = chunk.collect(); + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &trace_coset[first_constraint.step]; + let second_vanish_point = &trace_coset[second_constraint.step]; + + let mut boundary_evaluations: Vec> = eval_points + .iter() + .map(|eval_point| { + line(eval_point, &first_vanish_point, &second_vanish_point) + }) + .collect(); + FieldElement::inplace_batch_inverse(&mut boundary_evaluations).unwrap(); + zerofiers_evaluations.push(boundary_evaluations); + }) + .collect() // TODO: We don't use this collect. + } + zerofiers_evaluations } + + /// For every columnm, and every constrain, returns the evaluation on the entire lde domain of the constrain function F(x) - I(x). + // Note this is the numerator of each pair of constrains for the composition polynomial. + pub fn evaluate_poly_constraints( + &self, + trace_coset: &Vec>, + eval_points: &Vec>, + lde_trace: &LDETraceTable, + ) -> Vec>> { + let mut poly_evaluations = Vec::new(); + for col in self.cols_for_boundary() { + self.constraints + .iter() + .filter(|constraint| constraint.col == col) + .chunks(2) + .into_iter() + .map(|chunk| { + let chunk: Vec<_> = chunk.collect(); + let first_constraint = &chunk[0]; + let second_constraint = &chunk[1]; + let first_vanish_point = &trace_coset[first_constraint.step]; + let first_value = first_constraint.value; + let second_vanish_point = &trace_coset[second_constraint.step]; + let second_value = second_constraint.value; + let boundary_evaluations = eval_points + .iter() + .zip(&lde_trace.table.columns()[col]) + .map(|(eval_point, lde_eval)| { + lde_eval + - interpolant( + &first_vanish_point, + &second_vanish_point, + first_value, + second_value, + eval_point, + ) + }) + .collect::>>(); + poly_evaluations.push(boundary_evaluations); + }) + .collect() // TODO: We are not using this collect. + } + poly_evaluations + } + + // /// For every columnm, and every constrain, returns the evaluation on the entire lde domain of the constrain function F(x) - I(x). + // // Note this is the numerator of each pair of constrains for the composition polynomial. + // pub fn evaluate_poly_constraints( + // &self, + // trace_coset: &Vec>, + // eval_points: &Vec>, + // lde_trace: &LDETraceTable, + // ) -> Vec>>> { + // self.cols_for_boundary() + // .iter() + // .map(|col| { + // self.constraints + // .iter() + // .filter(|constraint| constraint.col == *col) + // .chunks(2) + // .into_iter() + // .map(|chunk| { + // let chunk: Vec<_> = chunk.collect(); + // let first_constraint = &chunk[0]; + // let second_constraint = &chunk[1]; + // let first_vanish_point = &trace_coset[first_constraint.step]; + // let first_value = first_constraint.value; + // let second_vanish_point = &trace_coset[second_constraint.step]; + // let second_value = second_constraint.value; + // eval_points + // .iter() + // .zip(&lde_trace.table.columns()[*col]) + // .map(|(eval_point, lde_eval)| { + // lde_eval + // - interpolant( + // &first_vanish_point, + // &second_vanish_point, + // first_value, + // second_value, + // eval_point, + // ) + // }) + // .collect::>>() + // }) + // .collect() + // }) + // .collect() + // } } #[cfg(test)] @@ -176,7 +275,7 @@ mod test { .collect()]; let constraints = BoundaryConstraints::from_constraints(vec![a0, a1]); - let zerofier = constraints.evaluate_zerofier(&trace_coset, 0, &eval_points); + let zerofier = constraints.evaluate_zerofiers(&trace_coset, &eval_points); assert_eq!(expected_zerofier, zerofier); }