8
8
#include < assert.h>
9
9
#include < unordered_set>
10
10
#include < list>
11
+ #include < memory>
11
12
12
13
namespace hnswlib {
13
14
typedef unsigned int tableint;
@@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
33
34
double mult_{0.0 }, revSize_{0.0 };
34
35
int maxlevel_{0 };
35
36
36
- VisitedListPool * visited_list_pool_{nullptr };
37
+ std::unique_ptr< VisitedListPool> visited_list_pool_{nullptr };
37
38
38
39
// Locks operations with element by label value
39
40
mutable std::vector<std::mutex> label_op_locks_;
@@ -92,16 +93,22 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
92
93
size_t ef_construction = 200 ,
93
94
size_t random_seed = 100 ,
94
95
bool allow_replace_deleted = false )
95
- : link_list_locks_(max_elements ),
96
- label_op_locks_ (MAX_LABEL_OPERATION_LOCKS ),
96
+ : label_op_locks_(MAX_LABEL_OPERATION_LOCKS ),
97
+ link_list_locks_ (max_elements ),
97
98
element_levels_(max_elements),
98
99
allow_replace_deleted_(allow_replace_deleted) {
99
100
max_elements_ = max_elements;
100
101
num_deleted_ = 0 ;
101
102
data_size_ = s->get_data_size ();
102
103
fstdistfunc_ = s->get_dist_func ();
103
104
dist_func_param_ = s->get_dist_func_param ();
104
- M_ = M;
105
+ if ( M <= 10000 ) {
106
+ M_ = M;
107
+ } else {
108
+ HNSWERR << " warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
109
+ HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
110
+ M_ = 10000 ;
111
+ }
105
112
maxM_ = M_;
106
113
maxM0_ = M_ * 2 ;
107
114
ef_construction_ = std::max (ef_construction, M_);
@@ -122,7 +129,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
122
129
123
130
cur_element_count = 0 ;
124
131
125
- visited_list_pool_ = new VisitedListPool (1 , max_elements);
132
+ visited_list_pool_ = std::unique_ptr<VisitedListPool>( new VisitedListPool (1 , max_elements) );
126
133
127
134
// initializations for special treatment of the first node
128
135
enterpoint_node_ = -1 ;
@@ -138,13 +145,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
138
145
139
146
140
147
~HierarchicalNSW () {
148
+ clear ();
149
+ }
150
+
151
+ void clear () {
141
152
free (data_level0_memory_);
153
+ data_level0_memory_ = nullptr ;
142
154
for (tableint i = 0 ; i < cur_element_count; i++) {
143
155
if (element_levels_[i] > 0 )
144
156
free (linkLists_[i]);
145
157
}
146
158
free (linkLists_);
147
- delete visited_list_pool_;
159
+ linkLists_ = nullptr ;
160
+ cur_element_count = 0 ;
161
+ visited_list_pool_.reset (nullptr );
148
162
}
149
163
150
164
@@ -291,9 +305,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
291
305
}
292
306
293
307
294
- template <bool has_deletions, bool collect_metrics = false >
308
+ // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
309
+ template <bool bare_bone_search = true , bool collect_metrics = false >
295
310
std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst>
296
- searchBaseLayerST (tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr ) const {
311
+ searchBaseLayerST (
312
+ tableint ep_id,
313
+ const void *data_point,
314
+ size_t ef,
315
+ BaseFilterFunctor* isIdAllowed = nullptr ,
316
+ BaseSearchStopCondition<dist_t >* stop_condition = nullptr ) const {
297
317
VisitedList *vl = visited_list_pool_->getFreeVisitedList ();
298
318
vl_type *visited_array = vl->mass ;
299
319
vl_type visited_array_tag = vl->curV ;
@@ -302,10 +322,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
302
322
std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> candidate_set;
303
323
304
324
dist_t lowerBound;
305
- if ((!has_deletions || !isMarkedDeleted (ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel (ep_id)))) {
306
- dist_t dist = fstdistfunc_ (data_point, getDataByInternalId (ep_id), dist_func_param_);
325
+ if (bare_bone_search ||
326
+ (!isMarkedDeleted (ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel (ep_id))))) {
327
+ char * ep_data = getDataByInternalId (ep_id);
328
+ dist_t dist = fstdistfunc_ (data_point, ep_data, dist_func_param_);
307
329
lowerBound = dist;
308
330
top_candidates.emplace (dist, ep_id);
331
+ if (!bare_bone_search && stop_condition) {
332
+ stop_condition->add_point_to_result (getExternalLabel (ep_id), ep_data, dist);
333
+ }
309
334
candidate_set.emplace (-dist, ep_id);
310
335
} else {
311
336
lowerBound = std::numeric_limits<dist_t >::max ();
@@ -316,9 +341,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
316
341
317
342
while (!candidate_set.empty ()) {
318
343
std::pair<dist_t , tableint> current_node_pair = candidate_set.top ();
344
+ dist_t candidate_dist = -current_node_pair.first ;
319
345
320
- if ((-current_node_pair.first ) > lowerBound &&
321
- (top_candidates.size () == ef || (!isIdAllowed && !has_deletions))) {
346
+ bool flag_stop_search;
347
+ if (bare_bone_search) {
348
+ flag_stop_search = candidate_dist > lowerBound;
349
+ } else {
350
+ if (stop_condition) {
351
+ flag_stop_search = stop_condition->should_stop_search (candidate_dist, lowerBound);
352
+ } else {
353
+ flag_stop_search = candidate_dist > lowerBound && top_candidates.size () == ef;
354
+ }
355
+ }
356
+ if (flag_stop_search) {
322
357
break ;
323
358
}
324
359
candidate_set.pop ();
@@ -353,19 +388,45 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
353
388
char *currObj1 = (getDataByInternalId (candidate_id));
354
389
dist_t dist = fstdistfunc_ (data_point, currObj1, dist_func_param_);
355
390
356
- if (top_candidates.size () < ef || lowerBound > dist) {
391
+ bool flag_consider_candidate;
392
+ if (!bare_bone_search && stop_condition) {
393
+ flag_consider_candidate = stop_condition->should_consider_candidate (dist, lowerBound);
394
+ } else {
395
+ flag_consider_candidate = top_candidates.size () < ef || lowerBound > dist;
396
+ }
397
+
398
+ if (flag_consider_candidate) {
357
399
candidate_set.emplace (-dist, candidate_id);
358
400
#ifdef USE_SSE
359
401
_mm_prefetch (data_level0_memory_ + candidate_set.top ().second * size_data_per_element_ +
360
402
offsetLevel0_, // /////////
361
403
_MM_HINT_T0); // //////////////////////
362
404
#endif
363
405
364
- if ((!has_deletions || !isMarkedDeleted (candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel (candidate_id))))
406
+ if (bare_bone_search ||
407
+ (!isMarkedDeleted (candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel (candidate_id))))) {
365
408
top_candidates.emplace (dist, candidate_id);
409
+ if (!bare_bone_search && stop_condition) {
410
+ stop_condition->add_point_to_result (getExternalLabel (candidate_id), currObj1, dist);
411
+ }
412
+ }
366
413
367
- if (top_candidates.size () > ef)
414
+ bool flag_remove_extra = false ;
415
+ if (!bare_bone_search && stop_condition) {
416
+ flag_remove_extra = stop_condition->should_remove_extra ();
417
+ } else {
418
+ flag_remove_extra = top_candidates.size () > ef;
419
+ }
420
+ while (flag_remove_extra) {
421
+ tableint id = top_candidates.top ().second ;
368
422
top_candidates.pop ();
423
+ if (!bare_bone_search && stop_condition) {
424
+ stop_condition->remove_point_from_result (getExternalLabel (id), getDataByInternalId (id), dist);
425
+ flag_remove_extra = stop_condition->should_remove_extra ();
426
+ } else {
427
+ flag_remove_extra = top_candidates.size () > ef;
428
+ }
429
+ }
369
430
370
431
if (!top_candidates.empty ())
371
432
lowerBound = top_candidates.top ().first ;
@@ -380,8 +441,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
380
441
381
442
382
443
void getNeighborsByHeuristic2 (
383
- std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> &top_candidates,
384
- const size_t M) {
444
+ std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> &top_candidates,
445
+ const size_t M) {
385
446
if (top_candidates.size () < M) {
386
447
return ;
387
448
}
@@ -573,8 +634,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
573
634
if (new_max_elements < cur_element_count)
574
635
throw std::runtime_error (" Cannot resize, max element is less than the current number of elements" );
575
636
576
- delete visited_list_pool_;
577
- visited_list_pool_ = new VisitedListPool (1 , new_max_elements);
637
+ visited_list_pool_.reset (new VisitedListPool (1 , new_max_elements));
578
638
579
639
element_levels_.resize (new_max_elements);
580
640
@@ -595,6 +655,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
595
655
max_elements_ = new_max_elements;
596
656
}
597
657
658
+ size_t indexFileSize () const {
659
+ size_t size = 0 ;
660
+ size += sizeof (offsetLevel0_);
661
+ size += sizeof (max_elements_);
662
+ size += sizeof (cur_element_count);
663
+ size += sizeof (size_data_per_element_);
664
+ size += sizeof (label_offset_);
665
+ size += sizeof (offsetData_);
666
+ size += sizeof (maxlevel_);
667
+ size += sizeof (enterpoint_node_);
668
+ size += sizeof (maxM_);
669
+
670
+ size += sizeof (maxM0_);
671
+ size += sizeof (M_);
672
+ size += sizeof (mult_);
673
+ size += sizeof (ef_construction_);
674
+
675
+ size += cur_element_count * size_data_per_element_;
676
+
677
+ for (size_t i = 0 ; i < cur_element_count; i++) {
678
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0 ;
679
+ size += sizeof (linkListSize);
680
+ size += linkListSize;
681
+ }
682
+ return size;
683
+ }
598
684
599
685
void saveIndex (const std::string &location) {
600
686
std::ofstream output (location, std::ios::binary);
@@ -633,6 +719,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
633
719
if (!input.is_open ())
634
720
throw std::runtime_error (" Cannot open file" );
635
721
722
+ clear ();
636
723
// get file size:
637
724
input.seekg (0 , input.end );
638
725
std::streampos total_filesize = input.tellg ();
@@ -698,7 +785,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
698
785
std::vector<std::mutex>(max_elements).swap (link_list_locks_);
699
786
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap (label_op_locks_);
700
787
701
- visited_list_pool_ = new VisitedListPool (1 , max_elements);
788
+ visited_list_pool_. reset ( new VisitedListPool (1 , max_elements) );
702
789
703
790
linkLists_ = (char **) malloc (sizeof (void *) * max_elements);
704
791
if (linkLists_ == nullptr )
@@ -752,7 +839,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
752
839
size_t dim = *((size_t *) dist_func_param_);
753
840
std::vector<data_t > data;
754
841
data_t * data_ptr = (data_t *) data_ptrv;
755
- for (int i = 0 ; i < dim; i++) {
842
+ for (size_t i = 0 ; i < dim; i++) {
756
843
data.push_back (*data_ptr);
757
844
data_ptr += 1 ;
758
845
}
@@ -1216,11 +1303,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1216
1303
}
1217
1304
1218
1305
std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> top_candidates;
1219
- if (num_deleted_) {
1220
- top_candidates = searchBaseLayerST<true , true >(
1306
+ bool bare_bone_search = !num_deleted_ && !isIdAllowed;
1307
+ if (bare_bone_search) {
1308
+ top_candidates = searchBaseLayerST<true >(
1221
1309
currObj, query_data, std::max (ef_, k), isIdAllowed);
1222
1310
} else {
1223
- top_candidates = searchBaseLayerST<false , true >(
1311
+ top_candidates = searchBaseLayerST<false >(
1224
1312
currObj, query_data, std::max (ef_, k), isIdAllowed);
1225
1313
}
1226
1314
@@ -1236,6 +1324,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1236
1324
}
1237
1325
1238
1326
1327
+ std::vector<std::pair<dist_t , labeltype >>
1328
+ searchStopConditionClosest (
1329
+ const void *query_data,
1330
+ BaseSearchStopCondition<dist_t >& stop_condition,
1331
+ BaseFilterFunctor* isIdAllowed = nullptr ) const {
1332
+ std::vector<std::pair<dist_t , labeltype >> result;
1333
+ if (cur_element_count == 0 ) return result;
1334
+
1335
+ tableint currObj = enterpoint_node_;
1336
+ dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
1337
+
1338
+ for (int level = maxlevel_; level > 0 ; level--) {
1339
+ bool changed = true ;
1340
+ while (changed) {
1341
+ changed = false ;
1342
+ unsigned int *data;
1343
+
1344
+ data = (unsigned int *) get_linklist (currObj, level);
1345
+ int size = getListCount (data);
1346
+ metric_hops++;
1347
+ metric_distance_computations+=size;
1348
+
1349
+ tableint *datal = (tableint *) (data + 1 );
1350
+ for (int i = 0 ; i < size; i++) {
1351
+ tableint cand = datal[i];
1352
+ if (cand < 0 || cand > max_elements_)
1353
+ throw std::runtime_error (" cand error" );
1354
+ dist_t d = fstdistfunc_ (query_data, getDataByInternalId (cand), dist_func_param_);
1355
+
1356
+ if (d < curdist) {
1357
+ curdist = d;
1358
+ currObj = cand;
1359
+ changed = true ;
1360
+ }
1361
+ }
1362
+ }
1363
+ }
1364
+
1365
+ std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> top_candidates;
1366
+ top_candidates = searchBaseLayerST<false >(currObj, query_data, 0 , isIdAllowed, &stop_condition);
1367
+
1368
+ size_t sz = top_candidates.size ();
1369
+ result.resize (sz);
1370
+ while (!top_candidates.empty ()) {
1371
+ result[--sz] = top_candidates.top ();
1372
+ top_candidates.pop ();
1373
+ }
1374
+
1375
+ stop_condition.filter_results (result);
1376
+
1377
+ return result;
1378
+ }
1379
+
1380
+
1239
1381
void checkIntegrity () {
1240
1382
int connections_checked = 0 ;
1241
1383
std::vector <int > inbound_connections_num (cur_element_count, 0 );
@@ -1246,7 +1388,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1246
1388
tableint *data = (tableint *) (ll_cur + 1 );
1247
1389
std::unordered_set<tableint> s;
1248
1390
for (int j = 0 ; j < size; j++) {
1249
- assert (data[j] > 0 );
1250
1391
assert (data[j] < cur_element_count);
1251
1392
assert (data[j] != i);
1252
1393
inbound_connections_num[data[j]]++;
0 commit comments