|
| 1 | +/** |
| 2 | + * @brief ビームサーチ |
| 3 | + * @docs docs/marathon/beam_search.md |
| 4 | + */ |
| 5 | + |
| 6 | +template <class BeamSearchState, class LiteBeamSearchState> class BeamSearch { |
| 7 | + public: |
| 8 | + BeamSearch() = default; |
| 9 | + void set_beam_width(int beam_width) { beam_width_ = beam_width; } |
| 10 | + void register_state(const BeamSearchState &state) { |
| 11 | + states_.emplace_back(state); |
| 12 | + } |
| 13 | + void register_lite_state(const LiteBeamSearchState &state) { |
| 14 | + lite_states_.emplace_back(state); |
| 15 | + lite_states_.back().state_idx = current_state_idx_; |
| 16 | + } |
| 17 | + void search(const auto &add_next_lite_states, const auto &to_next_state) { |
| 18 | + lite_states_.clear(); |
| 19 | + for(current_state_idx_ = 0; current_state_idx_ < states_.size(); |
| 20 | + current_state_idx_++) { |
| 21 | + add_next_lite_states(states_[current_state_idx_]); |
| 22 | + } |
| 23 | + const int num_select = min((int)lite_states_.size(), beam_width_); |
| 24 | + nth_element(lite_states_.begin(), lite_states_.begin() + num_select, |
| 25 | + lite_states_.end()); |
| 26 | + vector<BeamSearchState> next_states(num_select); |
| 27 | + for(int i = 0; i < num_select; i++) { |
| 28 | + next_states[i] = to_next_state(lite_states_[i], |
| 29 | + states_[lite_states_[i].state_idx]); |
| 30 | + } |
| 31 | + states_ = move(next_states); |
| 32 | + } |
| 33 | + const BeamSearchState &get_best_state() const { |
| 34 | + return *min_element(states_.begin(), states_.end()); |
| 35 | + } |
| 36 | + |
| 37 | + private: |
| 38 | + int beam_width_; |
| 39 | + size_t current_state_idx_; |
| 40 | + vector<BeamSearchState> states_; |
| 41 | + vector<LiteBeamSearchState> lite_states_; |
| 42 | +}; |
| 43 | + |
| 44 | +struct LiteBeamSearchStateBase { |
| 45 | + int state_idx; |
| 46 | + LiteBeamSearchStateBase() : state_idx(-1) {} |
| 47 | +}; |
0 commit comments