diff --git a/Cargo.toml b/Cargo.toml index 4fc932b79..354f48241 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ thiserror = "2.0" lazy_static = "1.4" probability = "0.20.3" derive_more = { version = "2.0.1", features = ["display"] } +rayon = "1.10.0" [profile.bench] debug = true diff --git a/benches/sample_z.rs b/benches/sample_z.rs index ba0902536..1dae33300 100644 --- a/benches/sample_z.rs +++ b/benches/sample_z.rs @@ -42,6 +42,39 @@ pub fn bench_sample_z_narrow(c: &mut Criterion) { c.bench_function("SampleZ narrow 10,000", |b| b.iter(sample_z_narrow)); } +/// benchmark creating a matrix of size 1x9 sampled by a comparatively narrow discrete Gaussian distribution. +pub fn bench_sample_z_1by9_matrix(c: &mut Criterion) { + let n = Z::from(1000); + let center = Q::from(0); + let s = Q::from(100); + + c.bench_function("SampleZ 1x9 matrix", |b| { + b.iter(|| MatZ::sample_discrete_gauss(1, 9, &n, ¢er, &s).unwrap()) + }); +} + +/// benchmark creating a matrix of size 1x10 sampled by a comparatively narrow discrete Gaussian distribution. +pub fn bench_sample_z_1by10_matrix(c: &mut Criterion) { + let n = Z::from(1000); + let center = Q::from(0); + let s = Q::from(100); + + c.bench_function("SampleZ 1x10 matrix", |b| { + b.iter(|| MatZ::sample_discrete_gauss(1, 10, &n, ¢er, &s).unwrap()) + }); +} + +/// benchmark creating a matrix of size 1x1 sampled by a comparatively narrow discrete Gaussian distribution. +pub fn bench_sample_z_1by1_matrix(c: &mut Criterion) { + let n = Z::from(1000); + let center = Q::from(0); + let s = Q::from(100); + + c.bench_function("SampleZ 1x1 matrix", |b| { + b.iter(|| MatZ::sample_discrete_gauss(1, 1, &n, ¢er, &s).unwrap()) + }); +} + /// benchmark creating a single integer sampled by a comparatively wide discrete Gaussian distribution. pub fn bench_sample_z_wide_single(c: &mut Criterion) { /// Create a single integer sampled by a comparatively wide discrete Gaussian distribution. @@ -75,5 +108,8 @@ criterion_group!( bench_sample_z_wide, bench_sample_z_narrow, bench_sample_z_wide_single, - bench_sample_z_narrow_single + bench_sample_z_narrow_single, + bench_sample_z_1by1_matrix, + bench_sample_z_1by9_matrix, + bench_sample_z_1by10_matrix ); diff --git a/src/integer/mat_z/sample/discrete_gauss.rs b/src/integer/mat_z/sample/discrete_gauss.rs index 60e6cefd2..cd47aa51f 100644 --- a/src/integer/mat_z/sample/discrete_gauss.rs +++ b/src/integer/mat_z/sample/discrete_gauss.rs @@ -63,10 +63,11 @@ impl MatZ { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, ¢er, &s)?; + let mut entries = dgis.sample_z_multiple(out.get_num_columns() * out.get_num_rows()); + for row in 0..out.get_num_rows() { for col in 0..out.get_num_columns() { - let sample = dgis.sample_z(); - out.set_entry(row, col, sample)?; + out.set_entry(row, col, entries.pop().unwrap()).unwrap(); } } diff --git a/src/integer/poly_over_z/sample/discrete_gauss.rs b/src/integer/poly_over_z/sample/discrete_gauss.rs index 0923b8809..97dca160b 100644 --- a/src/integer/poly_over_z/sample/discrete_gauss.rs +++ b/src/integer/poly_over_z/sample/discrete_gauss.rs @@ -62,9 +62,9 @@ impl PolyOverZ { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, ¢er, &s)?; + let mut entries = dgis.sample_z_multiple(max_degree + 1); for index in 0..=max_degree { - let sample = dgis.sample_z(); - poly.set_coeff(index, &sample)?; + poly.set_coeff(index, entries.pop().unwrap())?; } Ok(poly) } diff --git a/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs b/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs index aa4e1b399..484d393b1 100644 --- a/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs +++ b/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs @@ -66,10 +66,11 @@ impl MatZq { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, ¢er, &s)?; + let mut entries = dgis.sample_z_multiple(out.get_num_columns() * out.get_num_rows()); + for row in 0..out.get_num_rows() { for col in 0..out.get_num_columns() { - let sample = dgis.sample_z(); - out.set_entry(row, col, sample).unwrap(); + out.set_entry(row, col, entries.pop().unwrap()).unwrap(); } } diff --git a/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs b/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs index fb79e239e..c22d281a0 100644 --- a/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs +++ b/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs @@ -67,9 +67,9 @@ impl PolyOverZq { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, ¢er, &s)?; + let mut entries = dgis.sample_z_multiple(max_degree + 1); for index in 0..=max_degree { - let sample = dgis.sample_z(); - poly.set_coeff(index, &sample)?; + poly.set_coeff(index, entries.pop().unwrap())?; } Ok(poly) } diff --git a/src/utils/sample/discrete_gauss.rs b/src/utils/sample/discrete_gauss.rs index 3e2ec4a06..2f3b55b9d 100644 --- a/src/utils/sample/discrete_gauss.rs +++ b/src/utils/sample/discrete_gauss.rs @@ -24,6 +24,10 @@ use crate::{ traits::{GetNumColumns, GetNumRows, Pow}, }; use rand::RngCore; +use rayon::{ + current_num_threads, + iter::{IntoParallelIterator, ParallelIterator}, +}; use serde::Serialize; use std::collections::HashMap; @@ -172,6 +176,65 @@ impl DiscreteGaussianIntegerSampler { } } } + + /// Chooses `nr_samples` samples according to the discrete Gaussian distribution out of + /// `[lower_bound , lower_bound + interval_size ]`. + /// + /// This function implements a multi-threaded version of [`DiscreteGaussianIntegerSampler::sample_z`] + /// that simply samples `nr_samples` many entries. + /// It first considers the number of available threads. + /// For each thread, a single sampler will be cloned from the origin (to ensure memory safety), + /// and if there is only one available thread, then we will not clone the sampler, ensuring that the actual + /// sampler will be updated with new hash values. + /// + /// Parameters: + /// - `nr_samples`: the number of `sample_z` samples that should be computed. + /// + /// # Examples + /// ``` + /// use qfall_math::{integer::Z, rational::Q}; + /// use qfall_math::utils::sample::discrete_gauss::DiscreteGaussianIntegerSampler; + /// let n = Z::from(1024); + /// let center = Q::ZERO; + /// let gaussian_parameter = Q::ONE; + /// + /// let mut dgis = DiscreteGaussianIntegerSampler::init(&n, ¢er, &gaussian_parameter).unwrap(); + /// + /// let samples_5 = dgis.sample_z_multiple(5); + /// assert_eq!(samples_5.len(), 5) + /// ``` + /// + /// # Panics ... + /// - if `nr_samples` is negative + pub fn sample_z_multiple(&mut self, nr_samples: i64) -> Vec { + let nr_threads = current_num_threads(); + let nr_samples = nr_samples as usize; + if nr_threads == 1 || nr_samples < 10 { + // no multithreading + (0..nr_samples).map(|_| self.sample_z()).collect() + } else { + // with multithreading + let entries_per_thread = nr_samples / nr_threads; + let remainder = nr_samples % nr_threads; + (0..nr_threads) + .into_par_iter() + .map(|thread_i| { + let mut dgis_thread = self.clone(); + let entries_thread_i = if thread_i < remainder { + entries_per_thread + 1 + } else { + entries_per_thread + }; + (0..entries_thread_i) + .map(|_| dgis_thread.sample_z()) + .collect() + }) + .reduce(Vec::new, |mut a, mut b| { + a.append(&mut b); + a + }) + } + } } /// Computes the value of the Gaussian function for `x`. @@ -800,3 +863,47 @@ mod test_sample_d { let _ = sample_d_precomputed_gso(&basis, &false_gso, &n, ¢er, &Q::from(5)).unwrap(); } } + +#[cfg(test)] +mod test_sample_z_multiple { + use crate::{ + integer::Z, rational::Q, utils::sample::discrete_gauss::DiscreteGaussianIntegerSampler, + }; + + /// Ensure that the function outputs the correct number of samples + #[test] + fn correct_number_of_samples() { + let n = Z::from(1024); + let center = Q::ZERO; + let gaussian_parameter = Q::ONE; + + let mut dgis = + DiscreteGaussianIntegerSampler::init(&n, ¢er, &gaussian_parameter).unwrap(); + + let samples_0 = dgis.sample_z_multiple(0); + let samples_1 = dgis.sample_z_multiple(1); + let samples_10 = dgis.sample_z_multiple(10); + let samples_110 = dgis.sample_z_multiple(110); + let samples_12410 = dgis.sample_z_multiple(12410); + + assert_eq!(0, samples_0.len()); + assert_eq!(1, samples_1.len()); + assert_eq!(10, samples_10.len()); + assert_eq!(110, samples_110.len()); + assert_eq!(12410, samples_12410.len()); + } + + /// Ensure that the function does not allow for negative number of samples + #[test] + #[should_panic] + fn panic_if_negative_nr_samples() { + let n = Z::from(1024); + let center = Q::ZERO; + let gaussian_parameter = Q::ONE; + + let mut dgis = + DiscreteGaussianIntegerSampler::init(&n, ¢er, &gaussian_parameter).unwrap(); + + let _ = dgis.sample_z_multiple(-1); + } +}