Skip to content

[WIP] Allow NUTS to do eager evaluation on forward and backward trajectory in parallel #3103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
19 changes: 13 additions & 6 deletions src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Original file line number Diff line number Diff line change
@@ -12,18 +12,25 @@ namespace mcmc {
* with a Gaussian-Euclidean disintegration and adaptive
* diagonal metric and adaptive step size
*/
template <class Model, class BaseRNG>
class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
template <class Model, class BaseRNG, bool ParallelBase = false>
class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG, ParallelBase>,
public stepsize_var_adapter {
public:
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<!ParallelBase_>* = nullptr>
adapt_diag_e_nuts(const Model& model, BaseRNG& rng)
: diag_e_nuts<Model, BaseRNG>(model, rng),
: diag_e_nuts<Model, BaseRNG, ParallelBase>(model, rng),
stepsize_var_adapter(model.num_params_r()) {}

~adapt_diag_e_nuts() {}
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<ParallelBase_>* = nullptr>
adapt_diag_e_nuts(const Model& model, std::vector<BaseRNG>& thread_rngs)
: diag_e_nuts<Model, BaseRNG, ParallelBase>(model, thread_rngs),
stepsize_var_adapter(model.num_params_r()) {}

sample transition(sample& init_sample, callbacks::logger& logger) {
sample s = diag_e_nuts<Model, BaseRNG>::transition(init_sample, logger);
inline sample transition(sample& init_sample, callbacks::logger& logger) {
sample s = diag_e_nuts<Model, BaseRNG, ParallelBase>::transition(
init_sample, logger);

if (this->adapt_flag_) {
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
780 changes: 780 additions & 0 deletions src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp
Original file line number Diff line number Diff line change
@@ -2,22 +2,32 @@
#define STAN_MCMC_HMC_NUTS_DIAG_E_NUTS_HPP

#include <stan/mcmc/hmc/nuts/base_nuts.hpp>
#include <stan/mcmc/hmc/nuts/base_parallel_nuts.hpp>
#include <stan/mcmc/hmc/hamiltonians/diag_e_point.hpp>
#include <stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp>
#include <stan/mcmc/hmc/integrators/expl_leapfrog.hpp>

namespace stan {
namespace mcmc {

/**
* The No-U-Turn sampler (NUTS) with multinomial sampling
* with a Gaussian-Euclidean disintegration and diagonal metric
*/
template <class Model, class BaseRNG>
class diag_e_nuts
: public base_nuts<Model, diag_e_metric, expl_leapfrog, BaseRNG> {
template <class Model, class BaseRNG, bool ParallelBase = false>
class diag_e_nuts : public base_nuts_ct<ParallelBase, Model, diag_e_metric,
expl_leapfrog, BaseRNG> {
using base_nuts_t = base_nuts_ct<ParallelBase, Model, diag_e_metric,
expl_leapfrog, BaseRNG>;

public:
diag_e_nuts(const Model& model, BaseRNG& rng)
: base_nuts<Model, diag_e_metric, expl_leapfrog, BaseRNG>(model, rng) {}
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<!ParallelBase_>* = nullptr>
diag_e_nuts(const Model& model, BaseRNG& rng) : base_nuts_t(model, rng) {}
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<ParallelBase_>* = nullptr>
diag_e_nuts(const Model& model, std::vector<BaseRNG>& thread_rngs)
: base_nuts_t(model, thread_rngs) {}
};

} // namespace mcmc
242 changes: 242 additions & 0 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
#ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_PARALLEL_HPP
#define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_PARALLEL_HPP

#include <stan/math/prim.hpp>
#include <stan/callbacks/interrupt.hpp>
#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/writer.hpp>
#include <stan/io/var_context.hpp>
#include <stan/mcmc/fixed_param_sampler.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp>
#include <stan/services/sample/hmc_nuts_diag_e_adapt.hpp>
#include <stan/services/util/run_adaptive_sampler.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/inv_metric.hpp>
#include <atomic>
#include <memory>
#include <vector>

namespace stan {
namespace services {
namespace sample {

/**
 * Runs multiple chains of HMC with NUTS with adaptation using diagonal
 * Euclidean metric with a pre-specified Euclidean metric.
 *
 * When the current TBB task arena allows only one thread, the call is
 * forwarded unchanged to `hmc_nuts_diag_e_adapt`. Otherwise one pseudo random
 * number generator is created per arena thread and the chains are distributed
 * over the arena with `tbb::parallel_for`, one chain per task.
 *
 * @tparam Model Model class
 * @tparam InitContextPtr A pointer with underlying type derived from
 * `stan::io::var_context`
 * @tparam InitInvContextPtr A pointer with underlying type derived from
 * `stan::io::var_context`
 * @tparam InitWriter A type derived from `stan::callbacks::writer`
 * @tparam SampleWriter A type derived from `stan::callbacks::writer`
 * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
 * @param[in] model Input model to test (with data already instantiated)
 * @param[in] num_chains The number of chains to run in parallel. `init`,
 * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
 * must be the same length as this value.
 * @param[in] init An std vector of init var contexts for initialization of each
 * chain.
 * @param[in] init_inv_metric An std vector of var contexts exposing an initial
 * diagonal inverse Euclidean metric for each chain (must be positive definite)
 * @param[in] random_seed random seed for the random number generator
 * @param[in] init_chain_id first chain id. In the parallel path the generator
 * of arena thread `t` is created from `random_seed` and `init_chain_id + t`;
 * chain `i` is reported to the sampler run as chain `init_chain_id + i`.
 * @param[in] init_radius radius to initialize
 * @param[in] num_warmup Number of warmup samples
 * @param[in] num_samples Number of samples
 * @param[in] num_thin Number to thin the samples
 * @param[in] save_warmup Indicates whether to save the warmup iterations
 * @param[in] refresh Controls the output
 * @param[in] stepsize initial stepsize for discrete evolution
 * @param[in] stepsize_jitter uniform random jitter of stepsize
 * @param[in] max_depth Maximum tree depth
 * @param[in] delta adaptation target acceptance statistic
 * @param[in] gamma adaptation regularization scale
 * @param[in] kappa adaptation relaxation exponent
 * @param[in] t0 adaptation iteration offset
 * @param[in] init_buffer width of initial fast adaptation interval
 * @param[in] term_buffer width of final fast adaptation interval
 * @param[in] window initial width of slow adaptation interval
 * @param[in,out] interrupt Callback for interrupts
 * @param[in,out] logger Logger for messages
 * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
 * inits of each chain.
 * @param[in,out] sample_writer std vector of Writers for draws of each chain.
 * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
 * information of each chain.
 * @return error_codes::OK if successful, error_codes::CONFIG if creating a
 * generator or configuring any chain throws `std::domain_error`
 */
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
          typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_diag_e_adapt_parallel(
    Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
    const std::vector<InitInvContextPtr>& init_inv_metric,
    unsigned int random_seed, unsigned int init_chain_id, double init_radius,
    int num_warmup, int num_samples, int num_thin, bool save_warmup,
    int refresh, double stepsize, double stepsize_jitter, int max_depth,
    double delta, double gamma, double kappa, double t0,
    unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
    callbacks::interrupt& interrupt, callbacks::logger& logger,
    std::vector<InitWriter>& init_writer,
    std::vector<SampleWriter>& sample_writer,
    std::vector<DiagnosticWriter>& diagnostic_writer) {
  // Query the arena size once; it decides both the fallback and the number
  // of generators to create.
  const int num_threads = tbb::this_task_arena::max_concurrency();
  if (num_threads == 1) {
    // Report through the logger callback rather than writing to stdout from
    // library code.
    logger.info("Running serial");
    return hmc_nuts_diag_e_adapt(
        model, num_chains, init, init_inv_metric, random_seed, init_chain_id,
        init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
        stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
        init_buffer, term_buffer, window, interrupt, logger, init_writer,
        sample_writer, diagnostic_writer);
  }

  // One generator per arena thread, indexed by the TBB thread index.
  std::vector<boost::ecuyer1988> rngs;
  rngs.reserve(num_threads);
  try {
    for (int i = 0; i < num_threads; ++i) {
      rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
    }
  } catch (const std::domain_error& e) {
    return error_codes::CONFIG;
  }

  // Several tasks may report a configuration failure concurrently, so the
  // shared status must be atomic.
  std::atomic<int> ret_code(error_codes::OK);
  tbb::parallel_for(
      tbb::blocked_range<size_t>(0, num_chains, 1),
      [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains,
       init_chain_id, &ret_code, &model, &rngs, &interrupt, &logger,
       &sample_writer, &init, &init_writer, &init_inv_metric, init_radius,
       delta, stepsize, max_depth, stepsize_jitter, gamma, kappa, t0,
       init_buffer, term_buffer, window,
       &diagnostic_writer](const tbb::blocked_range<size_t>& r) {
        // Generator owned by the thread executing this task.
        boost::ecuyer1988& thread_rng
            = rngs[tbb::this_task_arena::current_thread_index()];
        using sample_t
            = stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988, true>;
        Eigen::VectorXd inv_metric;
        std::vector<double> cont_vector;
        for (size_t i = r.begin(); i != r.end(); ++i) {
          sample_t sampler(model, rngs);
          try {
            cont_vector
                = util::initialize(model, *init[i], thread_rng, init_radius,
                                   true, logger, init_writer[i]);
            inv_metric = util::read_diag_inv_metric(
                *init_inv_metric[i], model.num_params_r(), logger);
            util::validate_diag_inv_metric(inv_metric, logger);

            sampler.set_metric(inv_metric);
            sampler.set_nominal_stepsize(stepsize);
            sampler.set_stepsize_jitter(stepsize_jitter);
            sampler.set_max_depth(max_depth);

            sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize));
            sampler.get_stepsize_adaptation().set_delta(delta);
            sampler.get_stepsize_adaptation().set_gamma(gamma);
            sampler.get_stepsize_adaptation().set_kappa(kappa);
            sampler.get_stepsize_adaptation().set_t0(t0);
            sampler.set_window_params(num_warmup, init_buffer, term_buffer,
                                      window, logger);
          } catch (const std::domain_error& e) {
            ret_code.store(error_codes::CONFIG);
            return;
          }
          // `rngs` has one entry per thread, not per chain: indexing it with
          // the chain index reads past the end when there are more chains
          // than threads and otherwise shares a generator with another
          // thread. Use this thread's generator instead.
          util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup,
                                     num_samples, num_thin, refresh,
                                     save_warmup, thread_rng, interrupt,
                                     logger, sample_writer[i],
                                     diagnostic_writer[i], init_chain_id + i,
                                     num_chains);
        }
      },
      tbb::simple_partitioner());
  return ret_code.load();
}

/**
 * Runs multiple chains of HMC with NUTS with adaptation using diagonal
 * Euclidean metric.
 *
 * When the current TBB task arena allows only one thread, the call is
 * forwarded to `hmc_nuts_diag_e_adapt`. Otherwise a unit diagonal inverse
 * metric is created for every chain and the work is delegated to the overload
 * of this function that takes `init_inv_metric`.
 *
 * @tparam Model Model class
 * @tparam InitContextPtr A pointer with underlying type derived from
 * `stan::io::var_context`
 * @tparam InitWriter A type derived from `stan::callbacks::writer`
 * @tparam SampleWriter A type derived from `stan::callbacks::writer`
 * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
 * @param[in] model Input model to test (with data already instantiated)
 * @param[in] num_chains The number of chains to run in parallel. `init`,
 * `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
 * length as this value.
 * @param[in] init An std vector of init var contexts for initialization of each
 * chain.
 * @param[in] random_seed random seed for the random number generator
 * @param[in] init_chain_id first chain id; passed on, together with
 * `random_seed`, to the function that creates the pseudo random number
 * generators
 * @param[in] init_radius radius to initialize
 * @param[in] num_warmup Number of warmup samples
 * @param[in] num_samples Number of samples
 * @param[in] num_thin Number to thin the samples
 * @param[in] save_warmup Indicates whether to save the warmup iterations
 * @param[in] refresh Controls the output
 * @param[in] stepsize initial stepsize for discrete evolution
 * @param[in] stepsize_jitter uniform random jitter of stepsize
 * @param[in] max_depth Maximum tree depth
 * @param[in] delta adaptation target acceptance statistic
 * @param[in] gamma adaptation regularization scale
 * @param[in] kappa adaptation relaxation exponent
 * @param[in] t0 adaptation iteration offset
 * @param[in] init_buffer width of initial fast adaptation interval
 * @param[in] term_buffer width of final fast adaptation interval
 * @param[in] window initial width of slow adaptation interval
 * @param[in,out] interrupt Callback for interrupts
 * @param[in,out] logger Logger for messages
 * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
 * inits of each chain.
 * @param[in,out] sample_writer std vector of Writers for draws of each chain.
 * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
 * information of each chain.
 * @return error_codes::OK if successful
 */
template <class Model, typename InitContextPtr, typename InitWriter,
          typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_diag_e_adapt_parallel(
    Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
    unsigned int random_seed, unsigned int init_chain_id, double init_radius,
    int num_warmup, int num_samples, int num_thin, bool save_warmup,
    int refresh, double stepsize, double stepsize_jitter, int max_depth,
    double delta, double gamma, double kappa, double t0,
    unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
    callbacks::interrupt& interrupt, callbacks::logger& logger,
    std::vector<InitWriter>& init_writer,
    std::vector<SampleWriter>& sample_writer,
    std::vector<DiagnosticWriter>& diagnostic_writer) {
  if (tbb::this_task_arena::max_concurrency() == 1) {
    // Report through the logger callback rather than writing to stdout from
    // library code.
    logger.info("Running serial");
    return hmc_nuts_diag_e_adapt(
        model, num_chains, init, random_seed, init_chain_id, init_radius,
        num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize,
        stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer,
        term_buffer, window, interrupt, logger, init_writer, sample_writer,
        diagnostic_writer);
  }

  // Build one unit diagonal inverse metric context per chain so the
  // metric-taking overload can be reused.
  std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
  unit_e_metrics.reserve(num_chains);
  for (size_t i = 0; i < num_chains; ++i) {
    unit_e_metrics.emplace_back(std::make_unique<stan::io::dump>(
        util::create_unit_e_diag_inv_metric(model.num_params_r())));
  }
  return hmc_nuts_diag_e_adapt_parallel(
      model, num_chains, init, unit_e_metrics, random_seed, init_chain_id,
      init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
      stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
      init_buffer, term_buffer, window, interrupt, logger, init_writer,
      sample_writer, diagnostic_writer);
}

} // namespace sample
} // namespace services
} // namespace stan
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#include <stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp>
#include <gtest/gtest.h>
#include <stan/io/empty_var_context.hpp>
#include <test/test-models/good/optimization/rosenbrock.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <iostream>

// Evaluated once during static initialization of this test binary, before
// any test body runs.
// NOTE(review): presumably this sets up the TBB thread pool that the parallel
// service relies on; the result is bound to a name only so the call can live
// at namespace scope — confirm against the Stan Math documentation.
auto&& blah = stan::math::init_threadpool_tbb();

// Number of chains requested by every test in this file; the fixture builds
// this many writers and init contexts.
static constexpr size_t num_chains = 4;
/**
 * Test fixture that owns the model plus one init writer, one parameter
 * writer, one diagnostic writer and one (empty) init context per chain.
 */
class ServicesSampleHmcNutsDiagEAdaptPar : public testing::Test {
 public:
  ServicesSampleHmcNutsDiagEAdaptPar() : model(data_context, 0, &model_log) {
    init.reserve(num_chains);
    parameter.reserve(num_chains);
    diagnostic.reserve(num_chains);
    context.reserve(num_chains);
    for (size_t chain_idx = 0; chain_idx < num_chains; ++chain_idx) {
      init.emplace_back();
      parameter.emplace_back();
      diagnostic.emplace_back();
      context.emplace_back(std::make_shared<stan::io::empty_var_context>());
    }
  }

  // `data_context` and `model_log` are declared before `model` because the
  // constructor's initializer for `model` reads both of them.
  stan::io::empty_var_context data_context;
  std::stringstream model_log;
  stan::test::unit::instrumented_logger logger;
  std::vector<stan::test::unit::instrumented_writer> init;
  std::vector<stan::test::unit::instrumented_writer> parameter;
  std::vector<stan::test::unit::instrumented_writer> diagnostic;
  std::vector<std::shared_ptr<stan::io::empty_var_context>> context;
  stan_model model;
};

// Runs the parallel service once and checks how often the interrupt and the
// per-chain writers were invoked.
TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, call_count) {
  // Seeding and initialization.
  const unsigned int random_seed = 0;
  const unsigned int chain = 1;
  const double init_radius = 0;

  // Iteration counts and output control.
  const int num_warmup = 200;
  const int num_samples = 400;
  const int num_thin = 5;
  const bool save_warmup = true;
  const int refresh = 0;

  // Sampler settings.
  const double stepsize = 0.1;
  const double stepsize_jitter = 0;
  const int max_depth = 8;

  // Adaptation settings.
  const double delta = .1;
  const double gamma = .1;
  const double kappa = .1;
  const double t0 = .1;
  const unsigned int init_buffer = 50;
  const unsigned int term_buffer = 50;
  const unsigned int window = 100;

  stan::test::unit::instrumented_interrupt interrupt;
  EXPECT_EQ(interrupt.call_count(), 0);

  const int return_code
      = stan::services::sample::hmc_nuts_diag_e_adapt_parallel(
          model, num_chains, context, random_seed, chain, init_radius,
          num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize,
          stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer,
          term_buffer, window, interrupt, logger, init, parameter,
          diagnostic);

  EXPECT_EQ(0, return_code);

  // One interrupt call per iteration per chain.
  EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count());

  // Each chain writes its header once and one row per thinned iteration.
  const int num_output_lines = (num_warmup + num_samples) / num_thin;
  for (size_t chain_idx = 0; chain_idx < num_chains; ++chain_idx) {
    auto& param_writer = parameter[chain_idx];
    auto& diag_writer = diagnostic[chain_idx];
    EXPECT_EQ(1, param_writer.call_count("vector_string"));
    EXPECT_EQ(num_output_lines, param_writer.call_count("vector_double"));
    EXPECT_EQ(1, diag_writer.call_count("vector_string"));
    EXPECT_EQ(num_output_lines, diag_writer.call_count("vector_double"));
  }
}

// Runs the parallel service once and checks, for every chain, the names and
// sizes of what was sent to the parameter and diagnostic writers.
TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, parameter_checks) {
  unsigned int random_seed = 0;
  unsigned int chain = 1;
  double init_radius = 0;
  int num_warmup = 200;
  int num_samples = 400;
  int num_thin = 5;
  bool save_warmup = true;
  int refresh = 0;
  double stepsize = 0.1;
  double stepsize_jitter = 0;
  int max_depth = 8;
  double delta = .1;
  double gamma = .1;
  double kappa = .1;
  double t0 = .1;
  unsigned int init_buffer = 50;
  unsigned int term_buffer = 50;
  unsigned int window = 100;
  stan::test::unit::instrumented_interrupt interrupt;
  EXPECT_EQ(interrupt.call_count(), 0);

  int return_code = stan::services::sample::hmc_nuts_diag_e_adapt_parallel(
      model, num_chains, context, random_seed, chain, init_radius, num_warmup,
      num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
      max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
      interrupt, logger, init, parameter, diagnostic);

  // Stop here on failure: the checks below index into the recorded output,
  // which is only meaningful after a successful run.
  ASSERT_EQ(return_code, 0);

  for (size_t i = 0; i < num_chains; ++i) {
    std::vector<std::vector<std::string>> parameter_names;
    parameter_names = parameter[i].vector_string_values();
    std::vector<std::vector<double>> parameter_values;
    parameter_values = parameter[i].vector_double_values();
    std::vector<std::vector<std::string>> diagnostic_names;
    diagnostic_names = diagnostic[i].vector_string_values();
    std::vector<std::vector<double>> diagnostic_values;
    diagnostic_values = diagnostic[i].vector_double_values();

    // Guard every `[0]` access below so an empty recording fails the test
    // instead of reading out of bounds.
    ASSERT_FALSE(parameter_names.empty());
    ASSERT_FALSE(parameter_values.empty());
    ASSERT_FALSE(diagnostic_names.empty());
    ASSERT_FALSE(diagnostic_values.empty());

    // Expectations of parameter names.
    ASSERT_EQ(9, parameter_names[0].size());
    EXPECT_EQ("lp__", parameter_names[0][0]);
    EXPECT_EQ("accept_stat__", parameter_names[0][1]);
    EXPECT_EQ("stepsize__", parameter_names[0][2]);
    EXPECT_EQ("treedepth__", parameter_names[0][3]);
    EXPECT_EQ("n_leapfrog__", parameter_names[0][4]);
    EXPECT_EQ("divergent__", parameter_names[0][5]);
    EXPECT_EQ("energy__", parameter_names[0][6]);
    EXPECT_EQ("x", parameter_names[0][7]);
    EXPECT_EQ("y", parameter_names[0][8]);

    // Expect one name per parameter value.
    EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
    EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());

    // Expect one set of output per thinned iteration.
    EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size());

    // Expectations of the leading diagnostic names.
    ASSERT_GE(diagnostic_names[0].size(), size_t{2});
    EXPECT_EQ("lp__", diagnostic_names[0][0]);
    EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
  }
}

// Runs the parallel service once and checks what reached the init writers
// and the logger.
TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, output_regression) {
  // Seeding and initialization.
  const unsigned int random_seed = 0;
  const unsigned int chain = 1;
  const double init_radius = 0;

  // Iteration counts and output control.
  const int num_warmup = 200;
  const int num_samples = 400;
  const int num_thin = 5;
  const bool save_warmup = true;
  const int refresh = 0;

  // Sampler settings.
  const double stepsize = 0.1;
  const double stepsize_jitter = 0;
  const int max_depth = 8;

  // Adaptation settings.
  const double delta = .1;
  const double gamma = .1;
  const double kappa = .1;
  const double t0 = .1;
  const unsigned int init_buffer = 50;
  const unsigned int term_buffer = 50;
  const unsigned int window = 100;

  stan::test::unit::instrumented_interrupt interrupt;
  EXPECT_EQ(interrupt.call_count(), 0);

  // The return code is intentionally not inspected by this test.
  stan::services::sample::hmc_nuts_diag_e_adapt_parallel(
      model, num_chains, context, random_seed, chain, init_radius, num_warmup,
      num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
      max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
      interrupt, logger, init, parameter, diagnostic);

  // No plain-string output is expected on any init writer.
  for (auto& chain_init_writer : init) {
    const std::vector<std::string> init_values
        = chain_init_writer.string_values();
    EXPECT_EQ(0, init_values.size());
  }

  // Each chain logs its timing summary once, and nothing is logged as error.
  EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:"));
  EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)"));
  EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)"));
  EXPECT_EQ(num_chains, logger.find_info("seconds (Total)"));
  EXPECT_EQ(0, logger.call_count_error());
}