From 86539c73004eca7b1a9ccc8172b2bcad3374dced Mon Sep 17 00:00:00 2001 From: Marvin Beckmann Date: Wed, 26 Feb 2025 10:37:03 +0100 Subject: [PATCH 1/4] implement an algorithm to get several sample_z samples using multi-threading --- Cargo.toml | 1 + src/integer/mat_z/sample/discrete_gauss.rs | 5 +- .../poly_over_z/sample/discrete_gauss.rs | 4 +- .../mat_zq/sample/discrete_gauss.rs | 5 +- .../poly_over_zq/sample/discrete_gauss.rs | 4 +- src/utils/sample/discrete_gauss.rs | 114 ++++++++++++++++++ 6 files changed, 125 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4fc932b79..354f48241 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ thiserror = "2.0" lazy_static = "1.4" probability = "0.20.3" derive_more = { version = "2.0.1", features = ["display"] } +rayon = "1.10.0" [profile.bench] debug = true diff --git a/src/integer/mat_z/sample/discrete_gauss.rs b/src/integer/mat_z/sample/discrete_gauss.rs index 60e6cefd2..cd47aa51f 100644 --- a/src/integer/mat_z/sample/discrete_gauss.rs +++ b/src/integer/mat_z/sample/discrete_gauss.rs @@ -63,10 +63,11 @@ impl MatZ { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &s)?; + let mut entries = dgis.sample_z_multiple(out.get_num_columns() * out.get_num_rows()); + for row in 0..out.get_num_rows() { for col in 0..out.get_num_columns() { - let sample = dgis.sample_z(); - out.set_entry(row, col, sample)?; + out.set_entry(row, col, entries.pop().unwrap()).unwrap(); } } diff --git a/src/integer/poly_over_z/sample/discrete_gauss.rs b/src/integer/poly_over_z/sample/discrete_gauss.rs index 0923b8809..97dca160b 100644 --- a/src/integer/poly_over_z/sample/discrete_gauss.rs +++ b/src/integer/poly_over_z/sample/discrete_gauss.rs @@ -62,9 +62,9 @@ impl PolyOverZ { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &s)?; + let mut entries = dgis.sample_z_multiple(max_degree + 1); for index in 0..=max_degree { - let sample = dgis.sample_z(); - poly.set_coeff(index, &sample)?; +
poly.set_coeff(index, entries.pop().unwrap())?; } Ok(poly) } diff --git a/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs b/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs index aa4e1b399..484d393b1 100644 --- a/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs +++ b/src/integer_mod_q/mat_zq/sample/discrete_gauss.rs @@ -66,10 +66,11 @@ impl MatZq { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &s)?; + let mut entries = dgis.sample_z_multiple(out.get_num_columns() * out.get_num_rows()); + for row in 0..out.get_num_rows() { for col in 0..out.get_num_columns() { - let sample = dgis.sample_z(); - out.set_entry(row, col, sample).unwrap(); + out.set_entry(row, col, entries.pop().unwrap()).unwrap(); } } diff --git a/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs b/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs index fb79e239e..c22d281a0 100644 --- a/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs +++ b/src/integer_mod_q/poly_over_zq/sample/discrete_gauss.rs @@ -67,9 +67,9 @@ impl PolyOverZq { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &s)?; + let mut entries = dgis.sample_z_multiple(max_degree + 1); for index in 0..=max_degree { - let sample = dgis.sample_z(); - poly.set_coeff(index, &sample)?; + poly.set_coeff(index, entries.pop().unwrap())?; } Ok(poly) } diff --git a/src/utils/sample/discrete_gauss.rs b/src/utils/sample/discrete_gauss.rs index 3e2ec4a06..4d01a3c0c 100644 --- a/src/utils/sample/discrete_gauss.rs +++ b/src/utils/sample/discrete_gauss.rs @@ -24,6 +24,10 @@ use crate::{ traits::{GetNumColumns, GetNumRows, Pow}, }; use rand::RngCore; +use rayon::{ + current_num_threads, + iter::{IntoParallelIterator, ParallelIterator}, +}; use serde::Serialize; use std::collections::HashMap; @@ -172,6 +176,72 @@ impl DiscreteGaussianIntegerSampler { } } } + + /// Chooses `nr_samples` samples according to the discrete Gaussian distribution out of + /// `[lower_bound , lower_bound + interval_size ]`.
+ /// This function implements a multi-threaded version of [`DiscreteGaussianIntegerSampler::sample_z`] + /// that simply samples `nr_samples` many entries. + /// It first considers the number of available threads. + /// For each thread, a single sampler will be cloned from the origin (to ensure memory safeness), + /// and if there is only one avalable thread, then we will not clone the sampler, ensuring that the actual + /// sampler will be updated with new hash values. + /// + /// Parameters: + /// - `nr_samples`: the number of `sample_z` samples that should be computed. + /// + /// # Examples + /// ``` + /// use qfall_math::{integer::Z, rational::Q}; + /// use qfall_math::utils::sample::discrete_gauss::DiscreteGaussianIntegerSampler; + /// let n = Z::from(1024); + /// let center = Q::ZERO; + /// let gaussian_parameter = Q::ONE; + /// + /// let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap(); + /// + /// let samples_5 = dgis.sample_z_multiple(5); + /// assert_eq!(samples_5.len(), 5) + /// ``` + /// + /// # Panics ... + /// - if `nr_samples` is negative + pub fn sample_z_multiple(&mut self, nr_samples: i64) -> Vec<Z> { + let nr_threads = current_num_threads(); + let nr_samples = nr_samples as usize; + if nr_threads == 1 { + // no multithreading + (0..nr_samples) + .into_iter() + .map(|_| self.sample_z()) + .collect() + } else { + // with multithreading + let entries_per_thread = nr_samples / nr_threads; + let remainder = nr_samples % nr_threads; + (0..nr_threads) + .into_par_iter() + .map(|thread_i| { + let mut dgis_thread = self.clone(); + let entries_thread_i = if thread_i < remainder { + entries_per_thread + 1 + } else { + entries_per_thread + }; + (0..entries_thread_i) + .into_iter() + .map(|_| dgis_thread.sample_z()) + .collect() + }) + .reduce( + || Vec::new(), + |mut a, mut b| { + a.append(&mut b); + a + }, + ) + } + } } /// Computes the value of the Gaussian function for `x`.
@@ -800,3 +870,47 @@ mod test_sample_d { let _ = sample_d_precomputed_gso(&basis, &false_gso, &n, &center, &Q::from(5)).unwrap(); } } + +#[cfg(test)] +mod test_sample_z_multiple { + use crate::{ + integer::Z, rational::Q, utils::sample::discrete_gauss::DiscreteGaussianIntegerSampler, + }; + + /// Ensure that the function outputs the correct number of samples + #[test] + fn correct_number_of_samples() { + let n = Z::from(1024); + let center = Q::ZERO; + let gaussian_parameter = Q::ONE; + + let mut dgis = + DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap(); + + let samples_0 = dgis.sample_z_multiple(0); + let samples_1 = dgis.sample_z_multiple(1); + let samples_10 = dgis.sample_z_multiple(10); + let samples_110 = dgis.sample_z_multiple(110); + let samples_12410 = dgis.sample_z_multiple(12410); + + assert_eq!(0, samples_0.len()); + assert_eq!(1, samples_1.len()); + assert_eq!(10, samples_10.len()); + assert_eq!(110, samples_110.len()); + assert_eq!(12410, samples_12410.len()); + } + + /// Ensure that the function does not allow for negative number of samples + #[test] + #[should_panic] + fn panic_if_negative_nr_samples() { + let n = Z::from(1024); + let center = Q::ZERO; + let gaussian_parameter = Q::ONE; + + let mut dgis = + DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap(); + + let samples_0 = dgis.sample_z_multiple(-1); + } +} From 65852339eb55415c824e9ed4ed2bcd110a608030 Mon Sep 17 00:00:00 2001 From: Marvin Beckmann Date: Thu, 27 Feb 2025 08:56:54 +0100 Subject: [PATCH 2/4] add benchmark for small matZ discrete Gaussian sampling --- benches/sample_z.rs | 38 +++++++++++++++++++++++++++++- src/utils/sample/discrete_gauss.rs | 4 ++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/benches/sample_z.rs b/benches/sample_z.rs index ba0902536..1dae33300 100644 --- a/benches/sample_z.rs +++ b/benches/sample_z.rs @@ -42,6 +42,39 @@ pub fn bench_sample_z_narrow(c: &mut Criterion) { c.bench_function("SampleZ
narrow 10,000", |b| b.iter(sample_z_narrow)); } +/// benchmark creating a matrix of size 1x9 sampled by a comparatively narrow discrete Gaussian distribution. +pub fn bench_sample_z_1by9_matrix(c: &mut Criterion) { + let n = Z::from(1000); + let center = Q::from(0); + let s = Q::from(100); + + c.bench_function("SampleZ 1x9 matrix", |b| { + b.iter(|| MatZ::sample_discrete_gauss(1, 9, &n, &center, &s).unwrap()) + }); +} + +/// benchmark creating a matrix of size 1x10 sampled by a comparatively narrow discrete Gaussian distribution. +pub fn bench_sample_z_1by10_matrix(c: &mut Criterion) { + let n = Z::from(1000); + let center = Q::from(0); + let s = Q::from(100); + + c.bench_function("SampleZ 1x10 matrix", |b| { + b.iter(|| MatZ::sample_discrete_gauss(1, 10, &n, &center, &s).unwrap()) + }); +} + +/// benchmark creating a matrix of size 1x1 sampled by a comparatively narrow discrete Gaussian distribution. +pub fn bench_sample_z_1by1_matrix(c: &mut Criterion) { + let n = Z::from(1000); + let center = Q::from(0); + let s = Q::from(100); + + c.bench_function("SampleZ 1x1 matrix", |b| { + b.iter(|| MatZ::sample_discrete_gauss(1, 1, &n, &center, &s).unwrap()) + }); +} + /// benchmark creating a single integer sampled by a comparatively wide discrete Gaussian distribution. pub fn bench_sample_z_wide_single(c: &mut Criterion) { /// Create a single integer sampled by a comparatively wide discrete Gaussian distribution.
@@ -75,5 +108,8 @@ criterion_group!( bench_sample_z_wide, bench_sample_z_narrow, bench_sample_z_wide_single, - bench_sample_z_narrow_single + bench_sample_z_narrow_single, + bench_sample_z_1by1_matrix, + bench_sample_z_1by9_matrix, + bench_sample_z_1by10_matrix ); diff --git a/src/utils/sample/discrete_gauss.rs b/src/utils/sample/discrete_gauss.rs index 4d01a3c0c..c79acdc3e 100644 --- a/src/utils/sample/discrete_gauss.rs +++ b/src/utils/sample/discrete_gauss.rs @@ -209,7 +209,7 @@ impl DiscreteGaussianIntegerSampler { pub fn sample_z_multiple(&mut self, nr_samples: i64) -> Vec<Z> { let nr_threads = current_num_threads(); let nr_samples = nr_samples as usize; - if nr_threads == 1 { + if nr_threads == 1 || nr_samples < 10 { // no multithreading (0..nr_samples) .into_iter() @@ -911,6 +911,6 @@ mod test_sample_z_multiple { let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap(); - let samples_0 = dgis.sample_z_multiple(-1); + let _ = dgis.sample_z_multiple(-1); } } From 365e9175b497c2349aab93fc22f98a98f7e82c1c Mon Sep 17 00:00:00 2001 From: Marvin Beckmann Date: Thu, 27 Feb 2025 08:58:22 +0100 Subject: [PATCH 3/4] fix clippy warnings in multi-threaded implementation of sample_z --- src/utils/sample/discrete_gauss.rs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/utils/sample/discrete_gauss.rs b/src/utils/sample/discrete_gauss.rs index c79acdc3e..b4a6d4ccc 100644 --- a/src/utils/sample/discrete_gauss.rs +++ b/src/utils/sample/discrete_gauss.rs @@ -211,10 +211,7 @@ impl DiscreteGaussianIntegerSampler { let nr_samples = nr_samples as usize; if nr_threads == 1 || nr_samples < 10 { // no multithreading - (0..nr_samples) - .into_iter() - .map(|_| self.sample_z()) - .collect() + (0..nr_samples).map(|_| self.sample_z()).collect() } else { // with multithreading let entries_per_thread = nr_samples / nr_threads; @@ -229,17 +226,13 @@ impl DiscreteGaussianIntegerSampler { entries_per_thread };
(0..entries_thread_i) - .into_iter() .map(|_| dgis_thread.sample_z()) .collect() }) - .reduce( - || Vec::new(), - |mut a, mut b| { - a.append(&mut b); - a - }, - ) + .reduce(Vec::new, |mut a, mut b| { + a.append(&mut b); + a + }) } } } From e8d7908f5eb7e6c98413a5c46b0d9bab1247a5d4 Mon Sep 17 00:00:00 2001 From: Marvin Beckmann Date: Mon, 3 Mar 2025 09:24:13 +0100 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Jan Niklas Siemer --- src/utils/sample/discrete_gauss.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/sample/discrete_gauss.rs b/src/utils/sample/discrete_gauss.rs index b4a6d4ccc..2f3b55b9d 100644 --- a/src/utils/sample/discrete_gauss.rs +++ b/src/utils/sample/discrete_gauss.rs @@ -183,8 +183,8 @@ impl DiscreteGaussianIntegerSampler { /// This function implements a multi-threaded version of [`DiscreteGaussianIntegerSampler::sample_z`] /// that simply samples `nr_samples` many entries. /// It first considers the number of available threads. - /// For each thread, a single sampler will be cloned from the origin (to ensure memory safeness), - /// and if there is only one avalable thread, then we will not clone the sampler, ensuring that the actual + /// For each thread, a single sampler will be cloned from the origin (to ensure memory safety), + /// and if there is only one available thread, then we will not clone the sampler, ensuring that the actual /// sampler will be updated with new hash values. /// /// Parameters: