diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 45e92380f57..c64f970e637 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -12,18 +12,25 @@ namespace mcmc { * with a Gaussian-Euclidean disintegration and adaptive * diagonal metric and adaptive step size */ -template -class adapt_diag_e_nuts : public diag_e_nuts, +template +class adapt_diag_e_nuts : public diag_e_nuts, public stepsize_var_adapter { public: + template * = nullptr> adapt_diag_e_nuts(const Model& model, BaseRNG& rng) - : diag_e_nuts(model, rng), + : diag_e_nuts(model, rng), stepsize_var_adapter(model.num_params_r()) {} - ~adapt_diag_e_nuts() {} + template * = nullptr> + adapt_diag_e_nuts(const Model& model, std::vector& thread_rngs) + : diag_e_nuts(model, thread_rngs), + stepsize_var_adapter(model.num_params_r()) {} - sample transition(sample& init_sample, callbacks::logger& logger) { - sample s = diag_e_nuts::transition(init_sample, logger); + inline sample transition(sample& init_sample, callbacks::logger& logger) { + sample s = diag_e_nuts::transition( + init_sample, logger); if (this->adapt_flag_) { this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, diff --git a/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp new file mode 100644 index 00000000000..2c3f486ad93 --- /dev/null +++ b/src/stan/mcmc/hmc/nuts/base_parallel_nuts.hpp @@ -0,0 +1,780 @@ +#ifndef STAN_MCMC_HMC_NUTS_BASE_PARALLEL_NUTS_HPP +#define STAN_MCMC_HMC_NUTS_BASE_PARALLEL_NUTS_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +// Prototype of speculative NUTS. +// Uses the Intel Flow Graph concept to turn NUTS into a parallel +// algorithm in that the forward and backward sweep run at the same +// time in parallel. + +namespace stan { +namespace mcmc { + +template +inline auto make_uniform_vec(std::vector& thread_rngs) { + /* + std::vector> rand_uniform_vec; + const size_t num_thread_rngs = thread_rngs.size(); + rand_uniform_vec.reserve(num_thread_rngs); + for (size_t i = 0; i < rand_uniform_vec.size(); ++i) { + rand_uniform_vec.emplace_back(thread_rngs[i]); + } + */ + return std::vector>(thread_rngs.begin(), + thread_rngs.end()); +} + +/** + * The No-U-Turn sampler (NUTS) with multinomial sampling + */ +template class Hamiltonian, + template class Integrator, class BaseRNG> +class base_parallel_nuts + : public base_hmc { + public: + using state_t = typename Hamiltonian::PointType; + + base_parallel_nuts(const Model& model, std::vector& thread_rngs) + : base_hmc( + model, thread_rngs[tbb::this_task_arena::current_thread_index()]), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + base_parallel_nuts(const Model& model, BaseRNG& rng, + std::vector& thread_rngs) + : base_hmc(model, rng), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + /** + * specialized constructor for specified diag mass matrix + */ + base_parallel_nuts(const Model& model, BaseRNG& rng, + Eigen::VectorXd& inv_e_metric, + std::vector& thread_rngs) + : base_hmc(model, rng, + inv_e_metric), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + /** + * specialized constructor for specified dense mass matrix + */ + base_parallel_nuts(const Model& model, BaseRNG& rng, + Eigen::MatrixXd& inv_e_metric, + std::vector& thread_rngs) + : base_hmc(model, rng, + inv_e_metric), + rand_uniform_vec_(make_uniform_vec(thread_rngs)) {} + + ~base_parallel_nuts() {} + + inline void set_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + inline void set_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_metric(inv_e_metric); + } + + inline void set_max_depth(int d) noexcept { + if (d > 0) { + max_depth_ = d; + } + } + + inline void set_max_delta(double d) noexcept { max_deltaH_ = d; } + + inline int get_max_depth() noexcept { return this->max_depth_; } + inline double get_max_delta() noexcept { return this->max_deltaH_; } + + // stores from left/right subtree entire information + struct subtree { + subtree(const double sign, const ps_point& z_end, + const Eigen::VectorXd& p_sharp_end, double H0) + : z_end_(z_end), + z_propose_(z_end), + p_sharp_end_(p_sharp_end), + H0_(H0), + sign_(sign), + n_leapfrog_(0), + sum_metro_prob_(0) {} + + ps_point z_end_; + ps_point z_propose_; + Eigen::VectorXd p_sharp_end_; + double H0_; + double sign_; + int n_leapfrog_{0}; + double sum_metro_prob_{0}; + }; + + // extends the tree into the direction of the sign of the + // subtree + using extend_tree_t = std::tuple; + + inline extend_tree_t extend_tree(int depth, subtree& tree, state_t& z, + callbacks::logger& logger) { + // save the current ends needed for later criterion computations + // Eigen::VectorXd p_end = tree.p_end_; + // Eigen::VectorXd p_sharp_end = tree.p_sharp_end_; + Eigen::VectorXd p_sharp_dummy + = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + + Eigen::VectorXd rho_subtree + = Eigen::VectorXd::Zero(tree.p_sharp_end_.size()); + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + tree.n_leapfrog_ = 0; + tree.sum_metro_prob_ = 0; + + z.ps_point::operator=(tree.z_end_); + + bool valid_subtree = build_tree( + depth, z, tree.z_propose_, p_sharp_dummy, tree.p_sharp_end_, + rho_subtree, tree.H0_, tree.sign_, tree.n_leapfrog_, + log_sum_weight_subtree, tree.sum_metro_prob_, logger); + + tree.z_end_.ps_point::operator=(z); + + return std::make_tuple(valid_subtree, log_sum_weight_subtree, rho_subtree, + tree.p_sharp_end_, tree.z_propose_, tree.n_leapfrog_, + tree.sum_metro_prob_); + } + + inline sample transition(sample& init_sample, callbacks::logger& logger) { + return transition_parallel(init_sample, logger); + } + + // this implementation builds up the dependence graph every call + // to transition. Things which should be refactored: + // 1. build up the nodes only once + // 2. add a prepare method to each node which samples its + // direction and needed random numbers for multinomial sampling + // 3. only the edges are added dynamically. So the forward nodes + // are wired-up and the backward nodes are wired-up if run + // parallel. If run serially, then each grow node is alternated + // with a check node. + sample transition_parallel(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + // ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + // int n_leapfrog = 0; + // double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // actual states which move... copy construct atm...revise?! + state_t z_fwd(this->z_); + state_t z_bck(this->z_); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + this->valid_trees_ = true; + + // the actual number of leapfrog steps in trajectory used + // excluding the ones executed speculative + int n_leapfrog = 0; + + // actually summed metropolis prob of used trajectory + double sum_metro_prob = 0; + + std::vector fwd_direction(this->max_depth_); + + for (std::size_t i = 0; i != this->max_depth_; ++i) + fwd_direction[i] = this->rand_uniform_() > 0.5; + + const std::size_t num_fwd + = std::accumulate(fwd_direction.begin(), fwd_direction.end(), 0); + const std::size_t num_bck = this->max_depth_ - num_fwd; + + /* + std::cout << "sampled turns: "; + for (std::size_t i = 0; i != this->max_depth_; ++i) { + if(fwd_direction[i]) + std::cout << "+,"; + else + std::cout << "-,"; + } + std::cout << std::endl; + */ + + tbb::concurrent_vector ends( + this->max_depth_, std::make_tuple(true, 0, Eigen::VectorXd(), + Eigen::VectorXd(), z_sample, 0, 0.0)); + tbb::concurrent_vector valid_subtree_fwd(num_fwd, true); + tbb::concurrent_vector valid_subtree_bck(num_bck, true); + + // HACK!!! + /* + callbacks::logger logger_fwd; + callbacks::logger logger_bck; + */ + // build TBB flow graph + tbb::flow::graph g; + + // add nodes which advance the left/right tree + using tree_builder_t = tbb::flow::continue_node; + + tbb::concurrent_vector all_builder_idx(this->max_depth_); + tbb::concurrent_vector fwd_builder; + fwd_builder.reserve(this->max_depth_); + tbb::concurrent_vector bck_builder; + bck_builder.reserve(this->max_depth_); + using builder_iter_t = tbb::concurrent_vector::iterator; + + // now wire up the fwd and bck build of the trees which + // depends on single-core or multi-core run + + // TODO: the extenders should also check for a global flag if + // we want to keep running + // TODO: We should also just run depth = 0 outside the loop to avoid the + // if statement here + for (std::size_t depth = 0, fwd_idx = 0, bck_idx = 0; + depth != this->max_depth_; ++depth) { + if (fwd_direction[depth]) { + builder_iter_t fwd_iter = fwd_builder.emplace_back( + g, [&, depth, fwd_idx](tbb::flow::continue_msg) { + // std::cout << "fwd turn at depth " << depth; + bool valid_parent + = fwd_idx == 0 ? true : valid_subtree_fwd[fwd_idx - 1]; + if (valid_parent) { + // std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_fwd, z_fwd, logger); + valid_subtree_fwd[fwd_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_fwd[fwd_idx] = false; + } + // std::cout << " nothing to do." << std::endl; + }); + if (fwd_idx != 0) { + // in this case this is not the starting node, we + // connect this with its predecessor + tbb::flow::make_edge(*(fwd_iter - 1), *fwd_iter); + } + all_builder_idx[depth] = fwd_idx; + ++fwd_idx; + } else { + builder_iter_t bck_iter = bck_builder.emplace_back( + g, [&, depth, bck_idx](tbb::flow::continue_msg) { + // std::cout << "bck turn at depth " << depth; + bool valid_parent + = bck_idx == 0 ? true : valid_subtree_bck[bck_idx - 1]; + if (valid_parent) { + // std::cout << " yes, here we go!" << std::endl; + ends[depth] = extend_tree(depth, tree_bck, z_bck, logger); + valid_subtree_bck[bck_idx] = std::get<0>(ends[depth]); + } else { + valid_subtree_bck[bck_idx] = false; + } + // std::cout << " nothing to do." << std::endl; + }); + if (bck_idx != 0) { + // in case this is not the starting node, we connect + // this with his predecessor + // tbb::flow::make_edge(bck_builder[bck_idx-1], bck_builder[bck_idx]); + tbb::flow::make_edge(*(bck_iter - 1), *bck_iter); + } + all_builder_idx[depth] = bck_idx; + ++bck_idx; + } + } + + // finally wire in the checker which accepts or rejects the + // proposed states from the subtrees + // typedef function_node< tbb::flow::tuple, bool> checker_t; + // typedef join_node< tbb::flow::tuple > joiner_t; + using checker_t = tbb::flow::continue_node; + + tbb::concurrent_vector checks; + // std::vector joins; + + Eigen::VectorXd p_sharp_fwd(p_sharp); + Eigen::VectorXd p_sharp_bck(p_sharp); + + for (std::size_t depth = 0; depth != this->max_depth_; ++depth) { + // joins.push_back(joiner_t(g)); + // std::cout << "creating check at depth " << depth << std::endl; + checks.emplace_back(g, [&, depth](tbb::flow::continue_msg) { + const bool is_fwd = fwd_direction[depth]; + + extend_tree_t& subtree_result = ends[depth]; + + // if we are still on the + // trajectories which are + // actually used update the + // running tree stats + if (this->valid_trees_) { + this->depth_ = depth + 1; + n_leapfrog += std::get<5>(subtree_result); + sum_metro_prob += std::get<6>(subtree_result); + } + + const bool valid_subtree + = is_fwd ? valid_subtree_fwd[all_builder_idx[depth]] + : valid_subtree_bck[all_builder_idx[depth]]; + + const bool is_valid = valid_subtree && this->valid_trees_; + + // std::cout << "CHECK at depth " << depth; + + if (!is_valid) { + // std::cout << " we are done (early)" << std::endl; + + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + return; + } + + // std::cout << " checking" << std::endl; + + double log_sum_weight_subtree = std::get<1>(subtree_result); + const Eigen::VectorXd& rho_subtree = std::get<2>(subtree_result); + + // update correct side + if (is_fwd) { + p_sharp_fwd = std::get<3>(subtree_result); + } else { + p_sharp_bck = std::get<3>(subtree_result); + } + + const ps_point& z_propose = std::get<4>(subtree_result); + + // update running sums + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob + = std::exp(log_sum_weight_subtree - log_sum_weight); + // if (this->rand_uniform_() < + // accept_prob) + // HACK + if (get_rand_uniform() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_bck, p_sharp_fwd, rho)) { + // setting this globally here + // will terminate all ongoing work + this->valid_trees_ = false; + // std::cout << " we are done (later)" << std::endl; + } + // std::cout << " continuing (later)" << std::endl; + }); + if (fwd_direction[depth]) { + // std::cout << "depth " << depth << ": joining fwd node " << + // all_builder_idx[depth] << " into join node." << std::endl; + tbb::flow::make_edge(fwd_builder[all_builder_idx[depth]], + checks.back()); + } else { + // std::cout << "depth " << depth << ": joining bck node " << + // all_builder_idx[depth] << " into join node." << std::endl; + tbb::flow::make_edge(bck_builder[all_builder_idx[depth]], + checks.back()); + } + if (depth != 0) { + tbb::flow::make_edge(checks[depth - 1], checks.back()); + } + } + + // kick off work + if (fwd_direction[0]) { + fwd_builder[0].try_put(tbb::flow::continue_msg()); + // the first turn is fwd, so kick off the bck walker if needed + if (num_bck != 0) { + bck_builder[0].try_put(tbb::flow::continue_msg()); + } + } else { + bck_builder[0].try_put(tbb::flow::continue_msg()); + if (num_fwd != 0) { + fwd_builder[0].try_put(tbb::flow::continue_msg()); + } + } + + g.wait_for_all(); + + this->n_leapfrog_ = n_leapfrog; + // this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + // this includes the speculative executed ones + // const double sum_metro_prob = tree_fwd.sum_metro_prob_ + + // tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + sample transition_refactored(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + const ps_point z_init(this->z_); + + ps_point z_sample(z_init); + ps_point z_propose(z_init); + + const Eigen::VectorXd p_sharp = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + // int n_leapfrog = 0; + // double sum_metro_prob = 0; + + // forward tree + subtree tree_fwd(1, z_init, p_sharp, H0); + // backward tree + subtree tree_bck(-1, z_init, p_sharp, H0); + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + bool valid_subtree; + double log_sum_weight_subtree; + Eigen::VectorXd rho_subtree; + + if (this->rand_uniform_() > 0.5) { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_fwd, this->z_, logger); + } else { + std::tie(valid_subtree, log_sum_weight_subtree, rho_subtree, z_propose) + = extend_tree(this->depth_, tree_bck, this->z_, logger); + } + + if (!valid_subtree) + break; + + // Sample from an accepted subtree + ++(this->depth_); + + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(tree_bck.p_sharp_end_, tree_fwd.p_sharp_end_, rho)) + break; + // if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + // break; + } + + // this->n_leapfrog_ = n_leapfrog; + this->n_leapfrog_ = tree_fwd.n_leapfrog_ + tree_bck.n_leapfrog_; + + const double sum_metro_prob + = tree_fwd.sum_metro_prob_ + tree_bck.sum_metro_prob_; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob + = sum_metro_prob / static_cast(this->n_leapfrog_); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + sample transition_old(sample& init_sample, callbacks::logger& logger) { + // Initialize the algorithm + this->sample_stepsize(); + + this->seed(init_sample.cont_params()); + + this->hamiltonian_.sample_p(this->z_, this->rand_int_); + this->hamiltonian_.init(this->z_, logger); + + ps_point z_plus(this->z_); + ps_point z_minus(z_plus); + + ps_point z_sample(z_plus); + ps_point z_propose(z_plus); + + Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); + // Eigen::VectorXd p_sharp_dummy = p_sharp_plus; + Eigen::VectorXd p_sharp_minus = p_sharp_plus; + Eigen::VectorXd rho = this->z_.p; + + double log_sum_weight = 0; // log(exp(H0 - H0)) + double H0 = this->hamiltonian_.H(this->z_); + int n_leapfrog = 0; + double sum_metro_prob = 0; + + // Build a trajectory until the NUTS criterion is no longer satisfied + this->depth_ = 0; + this->divergent_ = false; + + while (this->depth_ < this->max_depth_) { + // Build a new subtree in a random direction + Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; + double log_sum_weight_subtree = -std::numeric_limits::infinity(); + + // this should be fine (modified from orig) + Eigen::VectorXd p_sharp_dummy = Eigen::VectorXd::Zero(this->z_.p.size()); + + if (this->rand_uniform_() > 0.5) { + this->z_.ps_point::operator=(z_plus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, p_sharp_dummy, + p_sharp_plus, rho_subtree, H0, 1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, logger); + z_plus.ps_point::operator=(this->z_); + } else { + this->z_.ps_point::operator=(z_minus); + valid_subtree + = build_tree(this->depth_, this->z_, z_propose, p_sharp_dummy, + p_sharp_minus, rho_subtree, H0, -1, n_leapfrog, + log_sum_weight_subtree, sum_metro_prob, logger); + z_minus.ps_point::operator=(this->z_); + } + + if (!valid_subtree) + break; + + // Sample from an accepted subtree + ++(this->depth_); + + if (log_sum_weight_subtree > log_sum_weight) { + z_sample = z_propose; + } else { + double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); + if (this->rand_uniform_() < accept_prob) + z_sample = z_propose; + } + + log_sum_weight + = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + // Break when NUTS criterion is no longer satisfied + rho += rho_subtree; + if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + break; + } + + this->n_leapfrog_ = n_leapfrog; + + // Compute average acceptance probabilty across entire trajectory, + // even over subtrees that may have been rejected + double accept_prob = sum_metro_prob / static_cast(n_leapfrog); + + this->z_.ps_point::operator=(z_sample); + this->energy_ = this->hamiltonian_.H(this->z_); + return sample(this->z_.q, -this->z_.V, accept_prob); + } + + void get_sampler_param_names(std::vector& names) { + names.push_back("stepsize__"); + names.push_back("treedepth__"); + names.push_back("n_leapfrog__"); + names.push_back("divergent__"); + names.push_back("energy__"); + } + + void get_sampler_params(std::vector& values) { + values.push_back(this->epsilon_); + values.push_back(this->depth_); + values.push_back(this->n_leapfrog_); + values.push_back(this->divergent_); + values.push_back(this->energy_); + } + + virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return p_sharp_plus.dot(rho) > 0 && p_sharp_minus.dot(rho) > 0; + } + + /** + * Recursively build a new subtree to completion or until + * the subtree becomes invalid. Returns validity of the + * resulting subtree. + * + * @param depth Depth of the desired subtree + * @param z_beg State beginning from subtree + * @param z_propose State proposed from subtree + * @param p_sharp_left p_sharp from left boundary of returned tree + * @param p_sharp_right p_sharp from the right boundary of returned tree + * @param rho Summed momentum across trajectory + * @param H0 Hamiltonian of initial state + * @param sign Direction in time to built subtree + * @param n_leapfrog Summed number of leapfrog evaluations + * @param log_sum_weight Log of summed weights across trajectory + * @param sum_metro_prob Summed Metropolis probabilities across trajectory + * @param logger Logger for messages + */ + bool build_tree(int depth, state_t& z_beg, ps_point& z_propose, + Eigen::VectorXd& p_sharp_left, Eigen::VectorXd& p_sharp_right, + Eigen::VectorXd& rho, double H0, double sign, int& n_leapfrog, + double& log_sum_weight, double& sum_metro_prob, + callbacks::logger& logger) { + // Base case + if (depth == 0) { + // check if trees are still valid or if we should terminate + if (!this->valid_trees_) + return false; + + this->integrator_.evolve(z_beg, this->hamiltonian_, sign * this->epsilon_, + logger); + + ++n_leapfrog; + + double h = this->hamiltonian_.H(z_beg); + if (boost::math::isnan(h)) + h = std::numeric_limits::infinity(); + + // TODO: in parallel case we cannot use the global divergent + // flag since this could be a speculative tree!! + // if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + bool is_divergent = (h - H0) > this->max_deltaH_; + // if ((h - H0) > this->max_deltaH_) this->divergent_ = true; + + log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h); + + if (H0 - h > 0) + sum_metro_prob += 1; + else + sum_metro_prob += std::exp(H0 - h); + + z_propose = z_beg; + rho += z_beg.p; + + p_sharp_left = this->hamiltonian_.dtau_dp(z_beg); + p_sharp_right = p_sharp_left; + + return !is_divergent; + } + // General recursion + Eigen::VectorXd p_sharp_dummy(z_beg.p.size()); + + // Build the left subtree + double log_sum_weight_left = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); + + bool valid_left = build_tree(depth - 1, z_beg, z_propose, p_sharp_left, + p_sharp_dummy, rho_left, H0, sign, n_leapfrog, + log_sum_weight_left, sum_metro_prob, logger); + + if (!valid_left) + return false; + + // Build the right subtree + ps_point z_propose_right(z_beg); + + double log_sum_weight_right = -std::numeric_limits::infinity(); + Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); + + bool valid_right + = build_tree(depth - 1, z_beg, z_propose_right, p_sharp_dummy, + p_sharp_right, rho_right, H0, sign, n_leapfrog, + log_sum_weight_right, sum_metro_prob, logger); + + if (!valid_right) + return false; + + // Multinomial sample from right subtree + double log_sum_weight_subtree + = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); + log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); + + if (log_sum_weight_right > log_sum_weight_subtree) { + z_propose = z_propose_right; + } else { + double accept_prob + = std::exp(log_sum_weight_right - log_sum_weight_subtree); + // if (this->rand_uniform_() < accept_prob) + if (get_rand_uniform() < accept_prob) + z_propose = z_propose_right; + } + + Eigen::VectorXd rho_subtree = rho_left + rho_right; + rho += rho_subtree; + + return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + } + + inline double get_rand_uniform() { + return this + ->rand_uniform_vec_[tbb::this_task_arena::current_thread_index()](); + } + + int depth_{0}; + int max_depth_{5}; + double max_deltaH_{1000}; + int n_leapfrog_{0}; + double energy_{0}; + bool valid_trees_{true}; + bool divergent_{false}; + // Uniform(0, 1) RNG + std::vector> rand_uniform_vec_; +}; + +template class Hamiltonian, + template class Integrator, class BaseRNG> +using base_nuts_ct = std::conditional_t< + ParallelBase, base_parallel_nuts, + base_nuts>; + +} // namespace mcmc +} // namespace stan +#endif diff --git a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp index 5f830f85cc6..241283f1f8f 100644 --- a/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/diag_e_nuts.hpp @@ -2,22 +2,32 @@ #define STAN_MCMC_HMC_NUTS_DIAG_E_NUTS_HPP #include +#include #include #include #include namespace stan { namespace mcmc { + /** * The No-U-Turn sampler (NUTS) with multinomial sampling * with a Gaussian-Euclidean disintegration and diagonal metric */ -template -class diag_e_nuts - : public base_nuts { +template +class diag_e_nuts : public base_nuts_ct { + using base_nuts_t = base_nuts_ct; + public: - diag_e_nuts(const Model& model, BaseRNG& rng) - : base_nuts(model, rng) {} + template * = nullptr> + diag_e_nuts(const Model& model, BaseRNG& rng) : base_nuts_t(model, rng) {} + template * = nullptr> + diag_e_nuts(const Model& model, std::vector& thread_rngs) + : base_nuts_t(model, thread_rngs) {} }; } // namespace mcmc diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp new file mode 100644 index 00000000000..3ebda83bd14 --- /dev/null +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt_parallel.hpp @@ -0,0 +1,242 @@ +#ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_PARALLEL_HPP +#define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_PARALLEL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace sample { + +/** + * Runs multiple chains of HMC with NUTS with adaptation using diagonal + * Euclidean metric with a pre-specified Euclidean metric. + * + * @tparam Model Model class + * @tparam InitContextPtr A pointer with underlying type derived from + `stan::io::var_context` + * @tparam InitInvContextPtr A pointer with underlying type derived from + `stan::io::var_context` + * @tparam SamplerWriter A type derived from `stan::callbacks::writer` + * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer` + * @tparam InitWriter A type derived from `stan::callbacks::writer` + * @param[in] model Input model to test (with data already instantiated) + * @param[in] num_chains The number of chains to run in parallel. `init`, + * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` + must + * be the same length as this value. + * @param[in] init An std vector of init var contexts for initialization of each + * chain. + * @param[in] init_inv_metric An std vector of var contexts exposing an initial + diagonal inverse Euclidean metric for each chain (must be positive definite) + * @param[in] random_seed random seed for the random number generator + * @param[in] init_chain_id first chain id. The pseudo random number generator + * will advance for each chain by an integer sequence from `init_chain_id` to + * `init_chain_id + num_chains - 1` + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer std vector of Writer callbacks for unconstrained + * inits of each chain. + * @param[in,out] sample_writer std vector of Writers for draws of each chain. + * @param[in,out] diagnostic_writer std vector of Writers for diagnostic + * information of each chain. + * @return error_codes::OK if successful + */ +template +int hmc_nuts_diag_e_adapt_parallel( + Model& model, size_t num_chains, const std::vector& init, + const std::vector& init_inv_metric, + unsigned int random_seed, unsigned int init_chain_id, double init_radius, + int num_warmup, int num_samples, int num_thin, bool save_warmup, + int refresh, double stepsize, double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, double t0, + unsigned int init_buffer, unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, callbacks::logger& logger, + std::vector& init_writer, + std::vector& sample_writer, + std::vector& diagnostic_writer) { + if (tbb::this_task_arena::max_concurrency() == 1) { + std::cout << "Running serial" << std::endl; + return hmc_nuts_diag_e_adapt( + model, num_chains, init, init_inv_metric, random_seed, init_chain_id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, + init_buffer, term_buffer, window, interrupt, logger, init_writer, + sample_writer, diagnostic_writer); + } + const int num_threads = tbb::this_task_arena::max_concurrency(); + std::vector rngs; + rngs.reserve(num_threads); + try { + for (int i = 0; i < num_threads; ++i) { + rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i)); + } + } catch (const std::domain_error& e) { + return error_codes::CONFIG; + } + int ret_code = error_codes::OK; + tbb::parallel_for( + tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, + init_chain_id, &ret_code, &model, &rngs, &interrupt, &logger, + &sample_writer, &init, &init_writer, &init_inv_metric, init_radius, + delta, stepsize, max_depth, stepsize_jitter, gamma, kappa, t0, + init_buffer, term_buffer, window, + &diagnostic_writer](const tbb::blocked_range& r) { + boost::ecuyer1988& thread_rng + = rngs[tbb::this_task_arena::current_thread_index()]; + using sample_t + = stan::mcmc::adapt_diag_e_nuts; + Eigen::VectorXd inv_metric; + std::vector cont_vector; + for (size_t i = r.begin(); i != r.end(); ++i) { + sample_t sampler(model, rngs); + try { + cont_vector + = util::initialize(model, *init[i], thread_rng, init_radius, + true, logger, init_writer[i]); + inv_metric = util::read_diag_inv_metric( + *init_inv_metric[i], model.num_params_r(), logger); + util::validate_diag_inv_metric(inv_metric, logger); + + sampler.set_metric(inv_metric); + sampler.set_nominal_stepsize(stepsize); + sampler.set_stepsize_jitter(stepsize_jitter); + sampler.set_max_depth(max_depth); + + sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize)); + sampler.get_stepsize_adaptation().set_delta(delta); + sampler.get_stepsize_adaptation().set_gamma(gamma); + sampler.get_stepsize_adaptation().set_kappa(kappa); + sampler.get_stepsize_adaptation().set_t0(t0); + sampler.set_window_params(num_warmup, init_buffer, term_buffer, + window, logger); + } catch (const std::domain_error& e) { + ret_code = error_codes::CONFIG; + return; + } + util::run_adaptive_sampler(sampler, model, cont_vector, num_warmup, + num_samples, num_thin, refresh, + save_warmup, rngs[i], interrupt, logger, + sample_writer[i], diagnostic_writer[i], + init_chain_id + i, num_chains); + } + }, + tbb::simple_partitioner()); + return ret_code; +} + +/** + * Runs multiple chains of HMC with NUTS with adaptation using diagonal + * Euclidean metric. + * + * @tparam Model Model class + * @tparam InitContextPtr A pointer with underlying type derived from + * `stan::io::var_context` + * @tparam SamplerWriter A type derived from `stan::callbacks::writer` + * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer` + * @tparam InitWriter A type derived from `stan::callbacks::writer` + * @param[in] model Input model to test (with data already instantiated) + * @param[in] num_chains The number of chains to run in parallel. `init`, + * `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same + * length as this value. + * @param[in] init An std vector of init var contexts for initialization of each + * chain. + * @param[in] random_seed random seed for the random number generator + * @param[in] init_chain_id first chain id. The pseudo random number generator + * will advance by for each chain by an integer sequence from `init_chain_id` to + * `init_chain_id+num_chains-1` + * @param[in] init_radius radius to initialize + * @param[in] num_warmup Number of warmup samples + * @param[in] num_samples Number of samples + * @param[in] num_thin Number to thin the samples + * @param[in] save_warmup Indicates whether to save the warmup iterations + * @param[in] refresh Controls the output + * @param[in] stepsize initial stepsize for discrete evolution + * @param[in] stepsize_jitter uniform random jitter of stepsize + * @param[in] max_depth Maximum tree depth + * @param[in] delta adaptation target acceptance statistic + * @param[in] gamma adaptation regularization scale + * @param[in] kappa adaptation relaxation exponent + * @param[in] t0 adaptation iteration offset + * @param[in] init_buffer width of initial fast adaptation interval + * @param[in] term_buffer width of final fast adaptation interval + * @param[in] window initial width of slow adaptation interval + * @param[in,out] interrupt Callback for interrupts + * @param[in,out] logger Logger for messages + * @param[in,out] init_writer std vector of Writer callbacks for unconstrained + * inits of each chain. + * @param[in,out] sample_writer std vector of Writers for draws of each chain. + * @param[in,out] diagnostic_writer std vector of Writers for diagnostic + * information of each chain. + * @return error_codes::OK if successful + */ +template +int hmc_nuts_diag_e_adapt_parallel( + Model& model, size_t num_chains, const std::vector& init, + unsigned int random_seed, unsigned int init_chain_id, double init_radius, + int num_warmup, int num_samples, int num_thin, bool save_warmup, + int refresh, double stepsize, double stepsize_jitter, int max_depth, + double delta, double gamma, double kappa, double t0, + unsigned int init_buffer, unsigned int term_buffer, unsigned int window, + callbacks::interrupt& interrupt, callbacks::logger& logger, + std::vector& init_writer, + std::vector& sample_writer, + std::vector& diagnostic_writer) { + if (tbb::this_task_arena::max_concurrency() == 1) { + std::cout << "Running serial" << std::endl; + return hmc_nuts_diag_e_adapt( + model, num_chains, init, random_seed, init_chain_id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, + term_buffer, window, interrupt, logger, init_writer, sample_writer, + diagnostic_writer); + } + std::vector> unit_e_metrics; + unit_e_metrics.reserve(num_chains); + for (size_t i = 0; i < num_chains; ++i) { + unit_e_metrics.emplace_back(std::make_unique( + util::create_unit_e_diag_inv_metric(model.num_params_r()))); + } + return hmc_nuts_diag_e_adapt_parallel( + model, num_chains, init, unit_e_metrics, random_seed, init_chain_id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, + init_buffer, term_buffer, window, interrupt, logger, init_writer, + sample_writer, diagnostic_writer); +} + +} // namespace sample +} // namespace services +} // namespace stan +#endif diff --git a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp new file mode 100644 index 00000000000..b11c2889bed --- /dev/null +++ b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_parallel_test.cpp @@ -0,0 +1,175 @@ +#include +#include +#include +#include +#include +#include + +auto&& blah = stan::math::init_threadpool_tbb(); + +static constexpr size_t num_chains = 4; +class ServicesSampleHmcNutsDiagEAdaptPar : public testing::Test { + public: + ServicesSampleHmcNutsDiagEAdaptPar() : model(data_context, 0, &model_log) { + for (int i = 0; i < num_chains; ++i) { + init.push_back(stan::test::unit::instrumented_writer{}); + parameter.push_back(stan::test::unit::instrumented_writer{}); + diagnostic.push_back(stan::test::unit::instrumented_writer{}); + context.push_back(std::make_shared()); + } + } + stan::io::empty_var_context data_context; + std::stringstream model_log; + stan::test::unit::instrumented_logger logger; + std::vector init; + std::vector parameter; + std::vector diagnostic; + std::vector> context; + stan_model model; +}; + +TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, call_count) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + int return_code = stan::services::sample::hmc_nuts_diag_e_adapt_parallel( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, parameter, diagnostic); + + EXPECT_EQ(0, return_code); + + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count()); + for (int i = 0; i < num_chains; ++i) { + EXPECT_EQ(1, parameter[i].call_count("vector_string")); + EXPECT_EQ(num_output_lines, parameter[i].call_count("vector_double")); + EXPECT_EQ(1, diagnostic[i].call_count("vector_string")); + EXPECT_EQ(num_output_lines, diagnostic[i].call_count("vector_double")); + } +} + +TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, parameter_checks) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + int return_code = stan::services::sample::hmc_nuts_diag_e_adapt_parallel( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, parameter, diagnostic); + + for (size_t i = 0; i < num_chains; ++i) { + std::vector> parameter_names; + parameter_names = parameter[i].vector_string_values(); + std::vector> parameter_values; + parameter_values = parameter[i].vector_double_values(); + std::vector> diagnostic_names; + diagnostic_names = diagnostic[i].vector_string_values(); + std::vector> diagnostic_values; + diagnostic_values = diagnostic[i].vector_double_values(); + + // Expectations of parameter parameter names. + ASSERT_EQ(9, parameter_names[0].size()); + EXPECT_EQ("lp__", parameter_names[0][0]); + EXPECT_EQ("accept_stat__", parameter_names[0][1]); + EXPECT_EQ("stepsize__", parameter_names[0][2]); + EXPECT_EQ("treedepth__", parameter_names[0][3]); + EXPECT_EQ("n_leapfrog__", parameter_names[0][4]); + EXPECT_EQ("divergent__", parameter_names[0][5]); + EXPECT_EQ("energy__", parameter_names[0][6]); + EXPECT_EQ("x", parameter_names[0][7]); + EXPECT_EQ("y", parameter_names[0][8]); + + // Expect one name per parameter value. + EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size()); + EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size()); + + EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size()); + + // Expect one call to set parameter names, and one set of output per + // iteration. + EXPECT_EQ("lp__", diagnostic_names[0][0]); + EXPECT_EQ("accept_stat__", diagnostic_names[0][1]); + } + EXPECT_EQ(return_code, 0); +} + +TEST_F(ServicesSampleHmcNutsDiagEAdaptPar, output_regression) { + unsigned int random_seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int num_warmup = 200; + int num_samples = 400; + int num_thin = 5; + bool save_warmup = true; + int refresh = 0; + double stepsize = 0.1; + double stepsize_jitter = 0; + int max_depth = 8; + double delta = .1; + double gamma = .1; + double kappa = .1; + double t0 = .1; + unsigned int init_buffer = 50; + unsigned int term_buffer = 50; + unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + + stan::services::sample::hmc_nuts_diag_e_adapt_parallel( + model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, parameter, diagnostic); + + for (auto&& init_it : init) { + std::vector init_values; + init_values = init_it.string_values(); + + EXPECT_EQ(0, init_values.size()); + } + + EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)")); + EXPECT_EQ(num_chains, logger.find_info("seconds (Total)")); + EXPECT_EQ(0, logger.call_count_error()); +}