Skip to content

Commit 8d7f410

Browse files
authored
Update hnswlib to v0.8.0 (#17)
* updated to hnswlib v0.8.0 * implemenmted `HNSWLib.BFIndex.get_max_elements/1` * implemenmted `HNSWLib.BFIndex.get_current_count/1` * implemenmted `HNSWLib.BFIndex.{set_num_threads/2,get_num_threads/1}` * implemenmted `HNSWLib.Index.index_file_size/1` * [fix-windows] iterators being deleted after erasing
1 parent f9f0264 commit 8d7f410

13 files changed

+746
-42
lines changed

.github/workflows/ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ jobs:
4242

4343
- uses: erlef/setup-beam@v1
4444
with:
45-
otp-version: 25
46-
elixir-version: 1.14
45+
otp-version: "26.2.1"
46+
elixir-version: "1.16.0"
4747

4848
- uses: ilammy/msvc-dev-cmd@v1
4949
with:

3rd_party/hnswlib/bruteforce.h

+10-4
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,15 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
8484

8585

8686
void removePoint(labeltype cur_external) {
87-
size_t cur_c = dict_external_to_internal[cur_external];
87+
std::unique_lock<std::mutex> lock(index_lock);
8888

89-
dict_external_to_internal.erase(cur_external);
89+
auto found = dict_external_to_internal.find(cur_external);
90+
if (found == dict_external_to_internal.end()) {
91+
return;
92+
}
93+
94+
size_t cur_c = found->second;
95+
dict_external_to_internal.erase(found);
9096

9197
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
9298
dict_external_to_internal[label] = cur_c;
@@ -106,7 +112,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
106112
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
107113
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
108114
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
109-
topResults.push(std::pair<dist_t, labeltype>(dist, label));
115+
topResults.emplace(dist, label);
110116
}
111117
}
112118
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
@@ -115,7 +121,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
115121
if (dist <= lastdist) {
116122
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
117123
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
118-
topResults.push(std::pair<dist_t, labeltype>(dist, label));
124+
topResults.emplace(dist, label);
119125
}
120126
if (topResults.size() > k)
121127
topResults.pop();

3rd_party/hnswlib/hnswalg.h

+166-25
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <assert.h>
99
#include <unordered_set>
1010
#include <list>
11+
#include <memory>
1112

1213
namespace hnswlib {
1314
typedef unsigned int tableint;
@@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
3334
double mult_{0.0}, revSize_{0.0};
3435
int maxlevel_{0};
3536

36-
VisitedListPool *visited_list_pool_{nullptr};
37+
std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};
3738

3839
// Locks operations with element by label value
3940
mutable std::vector<std::mutex> label_op_locks_;
@@ -92,16 +93,22 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
9293
size_t ef_construction = 200,
9394
size_t random_seed = 100,
9495
bool allow_replace_deleted = false)
95-
: link_list_locks_(max_elements),
96-
label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
96+
: label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
97+
link_list_locks_(max_elements),
9798
element_levels_(max_elements),
9899
allow_replace_deleted_(allow_replace_deleted) {
99100
max_elements_ = max_elements;
100101
num_deleted_ = 0;
101102
data_size_ = s->get_data_size();
102103
fstdistfunc_ = s->get_dist_func();
103104
dist_func_param_ = s->get_dist_func_param();
104-
M_ = M;
105+
if ( M <= 10000 ) {
106+
M_ = M;
107+
} else {
108+
HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
109+
HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
110+
M_ = 10000;
111+
}
105112
maxM_ = M_;
106113
maxM0_ = M_ * 2;
107114
ef_construction_ = std::max(ef_construction, M_);
@@ -122,7 +129,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
122129

123130
cur_element_count = 0;
124131

125-
visited_list_pool_ = new VisitedListPool(1, max_elements);
132+
visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));
126133

127134
// initializations for special treatment of the first node
128135
enterpoint_node_ = -1;
@@ -138,13 +145,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
138145

139146

140147
~HierarchicalNSW() {
148+
clear();
149+
}
150+
151+
void clear() {
141152
free(data_level0_memory_);
153+
data_level0_memory_ = nullptr;
142154
for (tableint i = 0; i < cur_element_count; i++) {
143155
if (element_levels_[i] > 0)
144156
free(linkLists_[i]);
145157
}
146158
free(linkLists_);
147-
delete visited_list_pool_;
159+
linkLists_ = nullptr;
160+
cur_element_count = 0;
161+
visited_list_pool_.reset(nullptr);
148162
}
149163

150164

@@ -291,9 +305,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
291305
}
292306

293307

294-
template <bool has_deletions, bool collect_metrics = false>
308+
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
309+
template <bool bare_bone_search = true, bool collect_metrics = false>
295310
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
296-
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
311+
searchBaseLayerST(
312+
tableint ep_id,
313+
const void *data_point,
314+
size_t ef,
315+
BaseFilterFunctor* isIdAllowed = nullptr,
316+
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
297317
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
298318
vl_type *visited_array = vl->mass;
299319
vl_type visited_array_tag = vl->curV;
@@ -302,10 +322,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
302322
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
303323

304324
dist_t lowerBound;
305-
if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
306-
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
325+
if (bare_bone_search ||
326+
(!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
327+
char* ep_data = getDataByInternalId(ep_id);
328+
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
307329
lowerBound = dist;
308330
top_candidates.emplace(dist, ep_id);
331+
if (!bare_bone_search && stop_condition) {
332+
stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
333+
}
309334
candidate_set.emplace(-dist, ep_id);
310335
} else {
311336
lowerBound = std::numeric_limits<dist_t>::max();
@@ -316,9 +341,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
316341

317342
while (!candidate_set.empty()) {
318343
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
344+
dist_t candidate_dist = -current_node_pair.first;
319345

320-
if ((-current_node_pair.first) > lowerBound &&
321-
(top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
346+
bool flag_stop_search;
347+
if (bare_bone_search) {
348+
flag_stop_search = candidate_dist > lowerBound;
349+
} else {
350+
if (stop_condition) {
351+
flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
352+
} else {
353+
flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
354+
}
355+
}
356+
if (flag_stop_search) {
322357
break;
323358
}
324359
candidate_set.pop();
@@ -353,19 +388,45 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
353388
char *currObj1 = (getDataByInternalId(candidate_id));
354389
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
355390

356-
if (top_candidates.size() < ef || lowerBound > dist) {
391+
bool flag_consider_candidate;
392+
if (!bare_bone_search && stop_condition) {
393+
flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
394+
} else {
395+
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
396+
}
397+
398+
if (flag_consider_candidate) {
357399
candidate_set.emplace(-dist, candidate_id);
358400
#ifdef USE_SSE
359401
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
360402
offsetLevel0_, ///////////
361403
_MM_HINT_T0); ////////////////////////
362404
#endif
363405

364-
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
406+
if (bare_bone_search ||
407+
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
365408
top_candidates.emplace(dist, candidate_id);
409+
if (!bare_bone_search && stop_condition) {
410+
stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
411+
}
412+
}
366413

367-
if (top_candidates.size() > ef)
414+
bool flag_remove_extra = false;
415+
if (!bare_bone_search && stop_condition) {
416+
flag_remove_extra = stop_condition->should_remove_extra();
417+
} else {
418+
flag_remove_extra = top_candidates.size() > ef;
419+
}
420+
while (flag_remove_extra) {
421+
tableint id = top_candidates.top().second;
368422
top_candidates.pop();
423+
if (!bare_bone_search && stop_condition) {
424+
stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
425+
flag_remove_extra = stop_condition->should_remove_extra();
426+
} else {
427+
flag_remove_extra = top_candidates.size() > ef;
428+
}
429+
}
369430

370431
if (!top_candidates.empty())
371432
lowerBound = top_candidates.top().first;
@@ -380,8 +441,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
380441

381442

382443
void getNeighborsByHeuristic2(
383-
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
384-
const size_t M) {
444+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
445+
const size_t M) {
385446
if (top_candidates.size() < M) {
386447
return;
387448
}
@@ -573,8 +634,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
573634
if (new_max_elements < cur_element_count)
574635
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
575636

576-
delete visited_list_pool_;
577-
visited_list_pool_ = new VisitedListPool(1, new_max_elements);
637+
visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));
578638

579639
element_levels_.resize(new_max_elements);
580640

@@ -595,6 +655,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
595655
max_elements_ = new_max_elements;
596656
}
597657

658+
size_t indexFileSize() const {
659+
size_t size = 0;
660+
size += sizeof(offsetLevel0_);
661+
size += sizeof(max_elements_);
662+
size += sizeof(cur_element_count);
663+
size += sizeof(size_data_per_element_);
664+
size += sizeof(label_offset_);
665+
size += sizeof(offsetData_);
666+
size += sizeof(maxlevel_);
667+
size += sizeof(enterpoint_node_);
668+
size += sizeof(maxM_);
669+
670+
size += sizeof(maxM0_);
671+
size += sizeof(M_);
672+
size += sizeof(mult_);
673+
size += sizeof(ef_construction_);
674+
675+
size += cur_element_count * size_data_per_element_;
676+
677+
for (size_t i = 0; i < cur_element_count; i++) {
678+
unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
679+
size += sizeof(linkListSize);
680+
size += linkListSize;
681+
}
682+
return size;
683+
}
598684

599685
void saveIndex(const std::string &location) {
600686
std::ofstream output(location, std::ios::binary);
@@ -633,6 +719,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
633719
if (!input.is_open())
634720
throw std::runtime_error("Cannot open file");
635721

722+
clear();
636723
// get file size:
637724
input.seekg(0, input.end);
638725
std::streampos total_filesize = input.tellg();
@@ -698,7 +785,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
698785
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
699786
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
700787

701-
visited_list_pool_ = new VisitedListPool(1, max_elements);
788+
visited_list_pool_.reset(new VisitedListPool(1, max_elements));
702789

703790
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
704791
if (linkLists_ == nullptr)
@@ -752,7 +839,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
752839
size_t dim = *((size_t *) dist_func_param_);
753840
std::vector<data_t> data;
754841
data_t* data_ptr = (data_t*) data_ptrv;
755-
for (int i = 0; i < dim; i++) {
842+
for (size_t i = 0; i < dim; i++) {
756843
data.push_back(*data_ptr);
757844
data_ptr += 1;
758845
}
@@ -1216,11 +1303,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12161303
}
12171304

12181305
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1219-
if (num_deleted_) {
1220-
top_candidates = searchBaseLayerST<true, true>(
1306+
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
1307+
if (bare_bone_search) {
1308+
top_candidates = searchBaseLayerST<true>(
12211309
currObj, query_data, std::max(ef_, k), isIdAllowed);
12221310
} else {
1223-
top_candidates = searchBaseLayerST<false, true>(
1311+
top_candidates = searchBaseLayerST<false>(
12241312
currObj, query_data, std::max(ef_, k), isIdAllowed);
12251313
}
12261314

@@ -1236,6 +1324,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12361324
}
12371325

12381326

1327+
std::vector<std::pair<dist_t, labeltype >>
1328+
searchStopConditionClosest(
1329+
const void *query_data,
1330+
BaseSearchStopCondition<dist_t>& stop_condition,
1331+
BaseFilterFunctor* isIdAllowed = nullptr) const {
1332+
std::vector<std::pair<dist_t, labeltype >> result;
1333+
if (cur_element_count == 0) return result;
1334+
1335+
tableint currObj = enterpoint_node_;
1336+
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1337+
1338+
for (int level = maxlevel_; level > 0; level--) {
1339+
bool changed = true;
1340+
while (changed) {
1341+
changed = false;
1342+
unsigned int *data;
1343+
1344+
data = (unsigned int *) get_linklist(currObj, level);
1345+
int size = getListCount(data);
1346+
metric_hops++;
1347+
metric_distance_computations+=size;
1348+
1349+
tableint *datal = (tableint *) (data + 1);
1350+
for (int i = 0; i < size; i++) {
1351+
tableint cand = datal[i];
1352+
if (cand < 0 || cand > max_elements_)
1353+
throw std::runtime_error("cand error");
1354+
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1355+
1356+
if (d < curdist) {
1357+
curdist = d;
1358+
currObj = cand;
1359+
changed = true;
1360+
}
1361+
}
1362+
}
1363+
}
1364+
1365+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1366+
top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);
1367+
1368+
size_t sz = top_candidates.size();
1369+
result.resize(sz);
1370+
while (!top_candidates.empty()) {
1371+
result[--sz] = top_candidates.top();
1372+
top_candidates.pop();
1373+
}
1374+
1375+
stop_condition.filter_results(result);
1376+
1377+
return result;
1378+
}
1379+
1380+
12391381
void checkIntegrity() {
12401382
int connections_checked = 0;
12411383
std::vector <int > inbound_connections_num(cur_element_count, 0);
@@ -1246,7 +1388,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12461388
tableint *data = (tableint *) (ll_cur + 1);
12471389
std::unordered_set<tableint> s;
12481390
for (int j = 0; j < size; j++) {
1249-
assert(data[j] > 0);
12501391
assert(data[j] < cur_element_count);
12511392
assert(data[j] != i);
12521393
inbound_connections_num[data[j]]++;

0 commit comments

Comments
 (0)