Skip to content

Commit 30b0195

Browse files
committed
Add a HNSW collector that exits early when nearest neighbor queue saturates (#14094)
* Add a HNSW early termination based on nearest neighbor queue saturation Co-authored-by: Benjamin Trent <[email protected]> (cherry picked from commit 525bf34)
1 parent 90e6030 commit 30b0195

10 files changed

+714
-1
lines changed

Diff for: lucene/CHANGES.txt

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ New Features
6565

6666
* GITHUB#14412: Allow skip cache factor to be updated dynamically. (Sagar Upadhyaya)
6767

68+
* GITHUB#14094: New KNN query that early terminates when HNSW nearest neighbor queue saturates. (Tommaso Teofili)
69+
6870
Improvements
6971
---------------------
7072

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import org.apache.lucene.search.knn.KnnSearchStrategy;
21+
22+
/**
23+
* A {@link KnnCollector.Decorator} that early exits when nearest neighbor queue keeps saturating
24+
* beyond a 'patience' parameter. This records the rate of collection of new nearest neighbors in
25+
* the {@code delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for
26+
* a number of consecutive node visits (e.g., the patience parameter), this early terminates.
27+
*
28+
* @lucene.experimental
29+
*/
30+
public class HnswQueueSaturationCollector extends KnnCollector.Decorator {
31+
32+
private final KnnCollector delegate;
33+
private final double saturationThreshold;
34+
private final int patience;
35+
private boolean patienceFinished;
36+
private int countSaturated;
37+
private int previousQueueSize;
38+
private int currentQueueSize;
39+
40+
HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) {
41+
super(delegate);
42+
this.delegate = delegate;
43+
this.previousQueueSize = 0;
44+
this.currentQueueSize = 0;
45+
this.countSaturated = 0;
46+
this.patienceFinished = false;
47+
this.saturationThreshold = saturationThreshold;
48+
this.patience = patience;
49+
}
50+
51+
@Override
52+
public boolean earlyTerminated() {
53+
return delegate.earlyTerminated() || patienceFinished;
54+
}
55+
56+
@Override
57+
public boolean collect(int docId, float similarity) {
58+
boolean collect = delegate.collect(docId, similarity);
59+
if (collect) {
60+
currentQueueSize++;
61+
}
62+
return collect;
63+
}
64+
65+
@Override
66+
public TopDocs topDocs() {
67+
TopDocs topDocs;
68+
if (patienceFinished && delegate.earlyTerminated() == false) {
69+
// this avoids re-running exact search in the filtered scenario when patience is exhausted
70+
TopDocs delegateDocs = delegate.topDocs();
71+
TotalHits totalHits =
72+
new TotalHits(delegateDocs.totalHits.value(), TotalHits.Relation.EQUAL_TO);
73+
topDocs = new TopDocs(totalHits, delegateDocs.scoreDocs);
74+
} else {
75+
topDocs = delegate.topDocs();
76+
}
77+
return topDocs;
78+
}
79+
80+
public void nextCandidate() {
81+
double queueSaturation =
82+
(double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize;
83+
previousQueueSize = currentQueueSize;
84+
if (queueSaturation >= saturationThreshold) {
85+
countSaturated++;
86+
} else {
87+
countSaturated = 0;
88+
}
89+
if (countSaturated > patience) {
90+
patienceFinished = true;
91+
}
92+
}
93+
94+
@Override
95+
public KnnSearchStrategy getSearchStrategy() {
96+
KnnSearchStrategy delegateStrategy = delegate.getSearchStrategy();
97+
assert delegateStrategy instanceof KnnSearchStrategy.Hnsw;
98+
return new KnnSearchStrategy.Patience(
99+
this, ((KnnSearchStrategy.Hnsw) delegateStrategy).filteredSearchThreshold());
100+
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.search;
18+
19+
import java.io.IOException;
20+
import java.util.Objects;
21+
import org.apache.lucene.index.FieldInfo;
22+
import org.apache.lucene.index.LeafReaderContext;
23+
import org.apache.lucene.index.QueryTimeout;
24+
import org.apache.lucene.search.knn.KnnCollectorManager;
25+
import org.apache.lucene.search.knn.KnnSearchStrategy;
26+
import org.apache.lucene.util.Bits;
27+
28+
/**
29+
* This is a version of knn vector query that exits early when HNSW queue saturates over a {@code
30+
* #saturationThreshold} for more than {@code #patience} times.
31+
*
32+
* <p>See <a
33+
* href="https://cs.uwaterloo.ca/~jimmylin/publications/Teofili_Lin_ECIR2025.pdf">"Patience in
34+
* Proximity: A Simple Early Termination Strategy for HNSW Graph Traversal in Approximate k-Nearest
35+
* Neighbor Search"</a> (Teofili and Lin). In ECIR '25: Proceedings of the 47th European Conference
36+
* on Information Retrieval.
37+
*
38+
* @lucene.experimental
39+
*/
40+
public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery {
41+
42+
private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d;
43+
44+
private final int patience;
45+
private final double saturationThreshold;
46+
47+
final AbstractKnnVectorQuery delegate;
48+
49+
/**
50+
* Construct a new PatienceKnnVectorQuery instance for a float vector field
51+
*
52+
* @param knnQuery the knn query to be seeded
53+
* @param saturationThreshold the early exit saturation threshold
54+
* @param patience the patience parameter
55+
* @return a new PatienceKnnVectorQuery instance
56+
* @lucene.experimental
57+
*/
58+
public static PatienceKnnVectorQuery fromFloatQuery(
59+
KnnFloatVectorQuery knnQuery, double saturationThreshold, int patience) {
60+
return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience);
61+
}
62+
63+
/**
64+
* Construct a new PatienceKnnVectorQuery instance for a float vector field
65+
*
66+
* @param knnQuery the knn query to be seeded
67+
* @return a new PatienceKnnVectorQuery instance
68+
* @lucene.experimental
69+
*/
70+
public static PatienceKnnVectorQuery fromFloatQuery(KnnFloatVectorQuery knnQuery) {
71+
return new PatienceKnnVectorQuery(
72+
knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery));
73+
}
74+
75+
/**
76+
* Construct a new PatienceKnnVectorQuery instance for a byte vector field
77+
*
78+
* @param knnQuery the knn query to be seeded
79+
* @param saturationThreshold the early exit saturation threshold
80+
* @param patience the patience parameter
81+
* @return a new PatienceKnnVectorQuery instance
82+
* @lucene.experimental
83+
*/
84+
public static PatienceKnnVectorQuery fromByteQuery(
85+
KnnByteVectorQuery knnQuery, double saturationThreshold, int patience) {
86+
return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience);
87+
}
88+
89+
/**
90+
* Construct a new PatienceKnnVectorQuery instance for a byte vector field
91+
*
92+
* @param knnQuery the knn query to be seeded
93+
* @return a new PatienceKnnVectorQuery instance
94+
* @lucene.experimental
95+
*/
96+
public static PatienceKnnVectorQuery fromByteQuery(KnnByteVectorQuery knnQuery) {
97+
return new PatienceKnnVectorQuery(
98+
knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery));
99+
}
100+
101+
/**
102+
* Construct a new PatienceKnnVectorQuery instance for seeded vector field
103+
*
104+
* @param knnQuery the knn query to be seeded
105+
* @param saturationThreshold the early exit saturation threshold
106+
* @param patience the patience parameter
107+
* @return a new PatienceKnnVectorQuery instance
108+
* @lucene.experimental
109+
*/
110+
public static PatienceKnnVectorQuery fromSeededQuery(
111+
SeededKnnVectorQuery knnQuery, double saturationThreshold, int patience) {
112+
return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience);
113+
}
114+
115+
/**
116+
* Construct a new PatienceKnnVectorQuery instance for seeded vector field
117+
*
118+
* @param knnQuery the knn query to be seeded
119+
* @return a new PatienceKnnVectorQuery instance
120+
* @lucene.experimental
121+
*/
122+
public static PatienceKnnVectorQuery fromSeededQuery(SeededKnnVectorQuery knnQuery) {
123+
return new PatienceKnnVectorQuery(
124+
knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery));
125+
}
126+
127+
PatienceKnnVectorQuery(
128+
AbstractKnnVectorQuery knnQuery, double saturationThreshold, int patience) {
129+
super(knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy);
130+
this.delegate = knnQuery;
131+
this.saturationThreshold = saturationThreshold;
132+
this.patience = patience;
133+
}
134+
135+
private static int defaultPatience(AbstractKnnVectorQuery delegate) {
136+
return Math.max(7, (int) (delegate.k * 0.3));
137+
}
138+
139+
@Override
140+
public String toString(String field) {
141+
return "PatienceKnnVectorQuery{"
142+
+ "saturationThreshold="
143+
+ saturationThreshold
144+
+ ", patience="
145+
+ patience
146+
+ ", delegate="
147+
+ delegate
148+
+ '}';
149+
}
150+
151+
@Override
152+
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
153+
return delegate.getKnnCollectorManager(k, searcher);
154+
}
155+
156+
@Override
157+
protected TopDocs approximateSearch(
158+
LeafReaderContext context,
159+
Bits acceptDocs,
160+
int visitedLimit,
161+
KnnCollectorManager knnCollectorManager)
162+
throws IOException {
163+
return delegate.approximateSearch(
164+
context, acceptDocs, visitedLimit, new PatienceCollectorManager(knnCollectorManager));
165+
}
166+
167+
@Override
168+
protected TopDocs exactSearch(
169+
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
170+
throws IOException {
171+
return delegate.exactSearch(context, acceptIterator, queryTimeout);
172+
}
173+
174+
@Override
175+
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
176+
return delegate.mergeLeafResults(perLeafResults);
177+
}
178+
179+
@Override
180+
public void visit(QueryVisitor visitor) {
181+
delegate.visit(visitor);
182+
}
183+
184+
@Override
185+
public boolean equals(Object o) {
186+
if (this == o) return true;
187+
if (o == null || getClass() != o.getClass()) return false;
188+
if (!super.equals(o)) return false;
189+
PatienceKnnVectorQuery that = (PatienceKnnVectorQuery) o;
190+
return saturationThreshold == that.saturationThreshold
191+
&& patience == that.patience
192+
&& Objects.equals(delegate, that.delegate);
193+
}
194+
195+
@Override
196+
public int hashCode() {
197+
return Objects.hash(super.hashCode(), saturationThreshold, patience, delegate);
198+
}
199+
200+
@Override
201+
public String getField() {
202+
return delegate.getField();
203+
}
204+
205+
@Override
206+
public int getK() {
207+
return delegate.getK();
208+
}
209+
210+
@Override
211+
public Query getFilter() {
212+
return delegate.getFilter();
213+
}
214+
215+
@Override
216+
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
217+
return delegate.createVectorScorer(context, fi);
218+
}
219+
220+
class PatienceCollectorManager implements KnnCollectorManager {
221+
final KnnCollectorManager knnCollectorManager;
222+
223+
PatienceCollectorManager(KnnCollectorManager knnCollectorManager) {
224+
this.knnCollectorManager = knnCollectorManager;
225+
}
226+
227+
@Override
228+
public KnnCollector newCollector(
229+
int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx)
230+
throws IOException {
231+
return new HnswQueueSaturationCollector(
232+
knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx),
233+
saturationThreshold,
234+
patience);
235+
}
236+
}
237+
}

0 commit comments

Comments
 (0)