1
- /* *
2
- * @brief ビームサーチ
3
- * @docs docs/marathon/beam_search.md
4
- */
5
-
6
1
template <typename BeamSearchState, typename LiteBeamSearchState>
7
2
class BeamSearchBase {
8
3
public:
@@ -15,7 +10,7 @@ class BeamSearchBase {
15
10
const function<void (BeamSearchState &)> &add_next_lite_states,
16
11
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
17
12
const BeamSearchState &get_best_state () const {
18
- if (states_.empty ()) {
13
+ if (states_.empty ()) {
19
14
throw runtime_error (" No states are registered." );
20
15
}
21
16
return *min_element (states_.begin (), states_.end ());
@@ -45,14 +40,14 @@ class BeamSearch : public BeamSearchBase<BeamSearchState, LiteBeamSearchState> {
45
40
const function<void (BeamSearchState &)> &add_next_lite_states,
46
41
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
47
42
lite_states_.clear ();
48
- for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
49
- this ->current_state_idx_ ++) {
43
+ for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
44
+ this ->current_state_idx_ ++) {
50
45
add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
51
46
}
52
47
const int num_select = min ((int )lite_states_.size (), this ->beam_width_ );
53
48
nth_element (lite_states_.begin (), lite_states_.begin () + num_select, lite_states_.end ());
54
49
vector<BeamSearchState> next_states (num_select);
55
- for (int i = 0 ; i < num_select; i++) {
50
+ for (int i = 0 ; i < num_select; i++) {
56
51
next_states[i] = to_next_state (lite_states_[i], this ->states_ [lite_states_[i].state_idx ]);
57
52
}
58
53
this ->states_ = move (next_states);
@@ -71,7 +66,7 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
71
66
state.state_idx = this ->current_state_idx_ ;
72
67
lite_states_.emplace (state);
73
68
// remove the worst (biggest) state.
74
- while ((int )lite_states_.size () > this ->beam_width_ ) {
69
+ while ((int )lite_states_.size () > this ->beam_width_ ) {
75
70
lite_states_.pop ();
76
71
}
77
72
}
@@ -82,17 +77,21 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
82
77
// NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
83
78
lite_states_ = priority_queue<LiteBeamSearchState>();
84
79
hash_values_.clear ();
85
- for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
86
- this ->current_state_idx_ ++) {
80
+ for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
81
+ this ->current_state_idx_ ++) {
87
82
add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
88
83
}
89
84
const int num_select = min ((int )lite_states_.size (), this ->beam_width_ );
85
+ vector<LiteBeamSearchState> lite_states_vec (lite_states_.size ());
86
+ for (int i = (int )lite_states_vec.size () - 1 ; i >= 0 ; i--) {
87
+ lite_states_vec[i] = lite_states_.top ();
88
+ lite_states_.pop ();
89
+ }
90
90
vector<BeamSearchState> next_states (num_select);
91
91
int num_next_states = 0 ;
92
- while (lite_states_.size ()) {
93
- auto lite_state = lite_states_.top ();
94
- lite_states_.pop ();
95
- if (hash_values_.count (lite_state.hash_value )) {
92
+ for (int i = 0 ; i < (int )lite_states_vec.size (); i++) {
93
+ auto &lite_state = lite_states_vec[i];
94
+ if (hash_values_.count (lite_state.hash_value )) {
96
95
continue ;
97
96
}
98
97
hash_values_.insert (lite_state.hash_value );
@@ -111,6 +110,7 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
111
110
template <typename score_t >
112
111
struct BeamSearchStateBase {
113
112
score_t score;
113
+ ll hash_value;
114
114
BeamSearchStateBase () : score(0 ) {}
115
115
virtual bool operator <(const BeamSearchStateBase &state) const = 0 ;
116
116
};
0 commit comments