Skip to content

Commit c0da818

Browse files
committed
fix beam_search
1 parent 84529dc commit c0da818

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

docs/marathon/beam_search.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
---
2+
title: ビームサーチ
3+
documentation_of: ./marathon/beam_search.cpp
4+
---
5+
16
- `BeamSearch``BeamSearchWithHash` の 2 種類が利用可能
27
- 重複した状態を除去したいときは `BeamSearchWithHash` を使う。そうでないときは `BeamSearch` を使う。
38
- `BeamSearch` だと $\log$ が落ちるはずなので、こっちのほうがイテレーションは回る
@@ -9,4 +14,4 @@
914
- `to_next_state`: 軽量化した状態を受けとり、それを元の (軽量化していない) 状態に変換する関数
1015

1116
AHC040 で使用実績あり。
12-
https://atcoder.jp/contests/ahc040/submissions/61300845
17+
[https://atcoder.jp/contests/ahc040/submissions/61300845](https://atcoder.jp/contests/ahc040/submissions/61300845)

marathon/beam_search.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
/**
2-
* @brief ビームサーチ
3-
* @docs docs/marathon/beam_search.md
4-
*/
5-
61
template <typename BeamSearchState, typename LiteBeamSearchState>
72
class BeamSearchBase {
83
public:
@@ -15,7 +10,7 @@ class BeamSearchBase {
1510
const function<void(BeamSearchState &)> &add_next_lite_states,
1611
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
1712
const BeamSearchState &get_best_state() const {
18-
if(states_.empty()) {
13+
if (states_.empty()) {
1914
throw runtime_error("No states are registered.");
2015
}
2116
return *min_element(states_.begin(), states_.end());
@@ -45,14 +40,14 @@ class BeamSearch : public BeamSearchBase<BeamSearchState, LiteBeamSearchState> {
4540
const function<void(BeamSearchState &)> &add_next_lite_states,
4641
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
4742
lite_states_.clear();
48-
for(this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
49-
this->current_state_idx_++) {
43+
for (this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
44+
this->current_state_idx_++) {
5045
add_next_lite_states(this->states_[this->current_state_idx_]);
5146
}
5247
const int num_select = min((int)lite_states_.size(), this->beam_width_);
5348
nth_element(lite_states_.begin(), lite_states_.begin() + num_select, lite_states_.end());
5449
vector<BeamSearchState> next_states(num_select);
55-
for(int i = 0; i < num_select; i++) {
50+
for (int i = 0; i < num_select; i++) {
5651
next_states[i] = to_next_state(lite_states_[i], this->states_[lite_states_[i].state_idx]);
5752
}
5853
this->states_ = move(next_states);
@@ -71,7 +66,7 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
7166
state.state_idx = this->current_state_idx_;
7267
lite_states_.emplace(state);
7368
// remove the worst (biggest) state.
74-
while((int)lite_states_.size() > this->beam_width_) {
69+
while ((int)lite_states_.size() > this->beam_width_) {
7570
lite_states_.pop();
7671
}
7772
}
@@ -82,17 +77,21 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
8277
// NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
8378
lite_states_ = priority_queue<LiteBeamSearchState>();
8479
hash_values_.clear();
85-
for(this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
86-
this->current_state_idx_++) {
80+
for (this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
81+
this->current_state_idx_++) {
8782
add_next_lite_states(this->states_[this->current_state_idx_]);
8883
}
8984
const int num_select = min((int)lite_states_.size(), this->beam_width_);
85+
vector<LiteBeamSearchState> lite_states_vec(lite_states_.size());
86+
for (int i = (int)lite_states_vec.size() - 1; i >= 0; i--) {
87+
lite_states_vec[i] = lite_states_.top();
88+
lite_states_.pop();
89+
}
9090
vector<BeamSearchState> next_states(num_select);
9191
int num_next_states = 0;
92-
while(lite_states_.size()) {
93-
auto lite_state = lite_states_.top();
94-
lite_states_.pop();
95-
if(hash_values_.count(lite_state.hash_value)) {
92+
for (int i = 0; i < (int)lite_states_vec.size(); i++) {
93+
auto &lite_state = lite_states_vec[i];
94+
if (hash_values_.count(lite_state.hash_value)) {
9695
continue;
9796
}
9897
hash_values_.insert(lite_state.hash_value);
@@ -111,6 +110,7 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
111110
template <typename score_t>
112111
struct BeamSearchStateBase {
113112
score_t score;
113+
ll hash_value;
114114
BeamSearchStateBase() : score(0) {}
115115
virtual bool operator<(const BeamSearchStateBase &state) const = 0;
116116
};

0 commit comments

Comments
 (0)