Skip to content

Commit 66b5e8e

Browse files
committed
add beam search
1 parent 67dcfb4 commit 66b5e8e

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

docs/marathon/beam_search.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- 状態 `BeamSearchState` と、軽量化した状態 `LiteBeamSearchState` の両方を定義して使う
2+
- 軽量化した状態は、参照する `BeamSearchState` のインデックス +α の情報を持たせる
3+
- 以下の 2 つの関数を定義してビームサーチを実行する
4+
- `add_next_lite_states`: 状態を受けとり、次の状態を軽量化したものを返す関数
5+
- `to_next_state`: 軽量化した状態を受けとり、それを元の (軽量化していない) 状態に変換する関数

marathon/beam_search.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/**
2+
* @brief ビームサーチ
3+
* @docs docs/marathon/beam_search.md
4+
*/
5+
6+
template <class BeamSearchState, class LiteBeamSearchState> class BeamSearch {
7+
public:
8+
BeamSearch() = default;
9+
void set_beam_width(int beam_width) { beam_width_ = beam_width; }
10+
void register_state(const BeamSearchState &state) {
11+
states_.emplace_back(state);
12+
}
13+
void register_lite_state(const LiteBeamSearchState &state) {
14+
lite_states_.emplace_back(state);
15+
lite_states_.back().state_idx = current_state_idx_;
16+
}
17+
void search(const auto &add_next_lite_states, const auto &to_next_state) {
18+
lite_states_.clear();
19+
for(current_state_idx_ = 0; current_state_idx_ < states_.size();
20+
current_state_idx_++) {
21+
add_next_lite_states(states_[current_state_idx_]);
22+
}
23+
const int num_select = min((int)lite_states_.size(), beam_width_);
24+
nth_element(lite_states_.begin(), lite_states_.begin() + num_select,
25+
lite_states_.end());
26+
vector<BeamSearchState> next_states(num_select);
27+
for(int i = 0; i < num_select; i++) {
28+
next_states[i] = to_next_state(lite_states_[i],
29+
states_[lite_states_[i].state_idx]);
30+
}
31+
states_ = move(next_states);
32+
}
33+
const BeamSearchState &get_best_state() const {
34+
return *min_element(states_.begin(), states_.end());
35+
}
36+
37+
private:
38+
int beam_width_;
39+
size_t current_state_idx_;
40+
vector<BeamSearchState> states_;
41+
vector<LiteBeamSearchState> lite_states_;
42+
};
43+
44+
struct LiteBeamSearchStateBase {
45+
int state_idx;
46+
LiteBeamSearchStateBase() : state_idx(-1) {}
47+
};

0 commit comments

Comments
 (0)