Multi-Threading for Sample_z #463
base: dev
Changes from all commits: 86539c7, 6585233, 365e917, e8d7908
@@ -24,6 +24,10 @@ use crate::{
    traits::{GetNumColumns, GetNumRows, Pow},
};
use rand::RngCore;
use rayon::{
    current_num_threads,
    iter::{IntoParallelIterator, ParallelIterator},
};
use serde::Serialize;
use std::collections::HashMap;
@@ -172,6 +176,65 @@ impl DiscreteGaussianIntegerSampler {
            }
        }
    }

    /// Chooses `nr_samples` samples according to the discrete Gaussian distribution out of
    /// `[lower_bound, lower_bound + interval_size]`.
    ///
    /// This function implements a multi-threaded version of [`DiscreteGaussianIntegerSampler::sample_z`]
    /// that samples `nr_samples` many entries.
    /// It first determines the number of available threads.
    /// For each thread, a sampler is cloned from the original one (to ensure memory safety).
    /// If only a single thread is available, the sampler is not cloned, so the original
    /// sampler is updated with new hash values.
    ///
    /// Parameters:
    /// - `nr_samples`: the number of `sample_z` samples that should be computed.
    ///
    /// # Examples
    /// ```
    /// use qfall_math::{integer::Z, rational::Q};
    /// use qfall_math::utils::sample::discrete_gauss::DiscreteGaussianIntegerSampler;
    /// let n = Z::from(1024);
    /// let center = Q::ZERO;
    /// let gaussian_parameter = Q::ONE;
    ///
    /// let mut dgis = DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap();
    ///
    /// let samples_5 = dgis.sample_z_multiple(5);
    /// assert_eq!(samples_5.len(), 5)
    /// ```
    ///
    /// # Panics ...
    /// - if `nr_samples` is negative.
    pub fn sample_z_multiple(&mut self, nr_samples: i64) -> Vec<Z> {
        let nr_threads = current_num_threads();
        // panics for negative `nr_samples`, as documented above
        let nr_samples = usize::try_from(nr_samples).expect("nr_samples must not be negative");
        if nr_threads == 1 || nr_samples < 10 {
            // no multithreading
            (0..nr_samples).map(|_| self.sample_z()).collect()
        } else {
            // with multithreading
            let entries_per_thread = nr_samples / nr_threads;
            let remainder = nr_samples % nr_threads;
            (0..nr_threads)
                .into_par_iter()
                .map(|thread_i| {
                    // clone the sampler so that each thread owns its own state
                    let mut dgis_thread = self.clone();
                    // the first `remainder` threads take one extra sample
                    let entries_thread_i = if thread_i < remainder {
                        entries_per_thread + 1
                    } else {
                        entries_per_thread
                    };
                    (0..entries_thread_i)
                        .map(|_| dgis_thread.sample_z())
                        .collect()
                })
                .reduce(Vec::new, |mut a, mut b| {
                    // concatenate the per-thread sample vectors
                    a.append(&mut b);
                    a
                })
Comment on lines +232 to +235

I haven't looked at the flamegraph, but considering my current experience with our library, setting the values of the matrix has a significant overhead. Furthermore, joining and iterating vectors shouldn't be the fastest thing in the world.

I will check later, but the join/reduce is probably efficient, as it moves the values.
        }
    }
}
/// Computes the value of the Gaussian function for `x`.
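For context (an editorial note, not part of the diff): the Gaussian function evaluated here is, assuming the convention common in the discrete Gaussian sampling literature, defined for center `c` and Gaussian parameter `s` as

```latex
\rho_{s,c}(x) = \exp\!\left(\frac{-\pi (x - c)^2}{s^2}\right)
```

The exact convention used by qfall_math may differ in normalization.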
@@ -800,3 +863,47 @@ mod test_sample_d {
        let _ = sample_d_precomputed_gso(&basis, &false_gso, &n, &center, &Q::from(5)).unwrap();
    }
}

#[cfg(test)]
mod test_sample_z_multiple {
    use crate::{
        integer::Z, rational::Q, utils::sample::discrete_gauss::DiscreteGaussianIntegerSampler,
    };

    /// Ensure that the function outputs the correct number of samples.
    #[test]
    fn correct_number_of_samples() {
        let n = Z::from(1024);
        let center = Q::ZERO;
        let gaussian_parameter = Q::ONE;

        let mut dgis =
            DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap();

        let samples_0 = dgis.sample_z_multiple(0);
        let samples_1 = dgis.sample_z_multiple(1);
        let samples_10 = dgis.sample_z_multiple(10);
        let samples_110 = dgis.sample_z_multiple(110);
        let samples_12410 = dgis.sample_z_multiple(12410);

        assert_eq!(0, samples_0.len());
        assert_eq!(1, samples_1.len());
        assert_eq!(10, samples_10.len());
        assert_eq!(110, samples_110.len());
        assert_eq!(12410, samples_12410.len());
    }

    /// Ensure that the function does not allow a negative number of samples.
    #[test]
    #[should_panic]
    fn panic_if_negative_nr_samples() {
        let n = Z::from(1024);
        let center = Q::ZERO;
        let gaussian_parameter = Q::ONE;

        let mut dgis =
            DiscreteGaussianIntegerSampler::init(&n, &center, &gaussian_parameter).unwrap();

        let _ = dgis.sample_z_multiple(-1);
    }
}
Just to provide another idea to make things quicker (although this might not be rayon's way of parallelising things): currently, the workload is split into similarly sized buckets, implicitly assuming that each bucket takes roughly the same time on its thread. This assumption should hold here for larger bucket sizes, but for smaller sizes it might be quicker to submit tasks to a pool of threads, where each thread picks up a new task as soon as it has finished its current one.

I looked again, but I was not able to find a better solution, and given that the function will probably not be called with small numbers of samples, I think the current implementation is reasonable. The problem with the dynamic approach is that I was not able to find a good way to also distribute the integer sampler; additionally, it introduces overhead through extra thread management, which might also increase the runtime due to the dynamic distribution of tasks.
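An editorial sketch of the dynamic variant discussed above (not part of this PR): rayon's `map_init` gives every internal work split its own state object, so the work-stealing scheduler decides the bucket sizes while each worker still draws from its own cloned sampler, which addresses the concern about distributing the integer sampler. The method name is hypothetical, and the trait bounds (`Clone` and `Sync` for the sampler, `Send` for `Z`) are assumptions.

```rust
/// Hypothetical alternative (sketch only), intended to live in the existing
/// `impl DiscreteGaussianIntegerSampler` block.
/// Assumes `DiscreteGaussianIntegerSampler: Clone + Sync` and `Z: Send`.
pub fn sample_z_multiple_dynamic(&self, nr_samples: i64) -> Vec<Z> {
    // panics for negative `nr_samples`, matching `sample_z_multiple`
    let nr_samples = usize::try_from(nr_samples).expect("nr_samples must not be negative");
    (0..nr_samples)
        .into_par_iter()
        .map_init(
            || self.clone(),                 // one cloned sampler per rayon work split
            |sampler, _| sampler.sample_z(), // reused for every sample in that split
        )
        .collect()
}
```

Whether the per-split clones outweigh the better load balancing would still need benchmarking, as noted in the thread.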