19
19
import org .apache .lucene .store .IndexInput ;
20
20
import org .apache .lucene .store .IndexOutput ;
21
21
import org .apache .lucene .store .MMapDirectory ;
22
- import org .apache .lucene .util .hnsw .RandomVectorScorer ;
23
22
import org .apache .lucene .util .hnsw .RandomVectorScorerSupplier ;
23
+ import org .apache .lucene .util .hnsw .UpdateableRandomVectorScorer ;
24
24
import org .apache .lucene .util .quantization .QuantizedByteVectorValues ;
25
25
import org .apache .lucene .util .quantization .ScalarQuantizer ;
26
26
50
50
// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100)
51
51
public class VectorScorerFactoryTests extends AbstractVectorTestCase {
52
52
53
+ private static final float DELTA = 1e-4f ;
54
+
53
55
// bounds of the range of values that can be seen by int7 scalar quantized vectors
54
56
static final byte MIN_INT7_VALUE = 0 ;
55
57
static final byte MAX_INT7_VALUE = 127 ;
@@ -99,10 +101,13 @@ void testSimpleImpl(long maxChunkSize) throws IOException {
99
101
float scc = values .getScalarQuantizer ().getConstantMultiplier ();
100
102
float expected = luceneScore (sim , vec1 , vec2 , scc , vec1Correction , vec2Correction );
101
103
102
- var luceneSupplier = luceneScoreSupplier (values , VectorSimilarityType .of (sim )).scorer (0 );
104
+ var luceneSupplier = luceneScoreSupplier (values , VectorSimilarityType .of (sim )).scorer ();
105
+ luceneSupplier .setScoringOrdinal (0 );
103
106
assertThat (luceneSupplier .score (1 ), equalTo (expected ));
104
107
var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , scc ).get ();
105
- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
108
+ var scorer = supplier .scorer ();
109
+ scorer .setScoringOrdinal (0 );
110
+ assertThat (scorer .score (1 ), equalTo (expected ));
106
111
107
112
if (Runtime .version ().feature () >= 22 ) {
108
113
var qScorer = factory .getInt7SQVectorScorer (VectorSimilarityType .of (sim ), values , query1 ).get ();
@@ -134,24 +139,32 @@ public void testNonNegativeDotProduct() throws IOException {
134
139
float expected = 0f ;
135
140
assertThat (luceneScore (DOT_PRODUCT , vec1 , vec2 , 1 , -5 , -5 ), equalTo (expected ));
136
141
var supplier = factory .getInt7SQVectorScorerSupplier (DOT_PRODUCT , in , values , 1 ).get ();
137
- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
138
- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
142
+ var scorer = supplier .scorer ();
143
+ scorer .setScoringOrdinal (0 );
144
+ assertThat (scorer .score (1 ), equalTo (expected ));
145
+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
139
146
// max inner product
140
147
expected = luceneScore (MAXIMUM_INNER_PRODUCT , vec1 , vec2 , 1 , -5 , -5 );
141
148
supplier = factory .getInt7SQVectorScorerSupplier (MAXIMUM_INNER_PRODUCT , in , values , 1 ).get ();
142
- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
143
- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
149
+ scorer = supplier .scorer ();
150
+ scorer .setScoringOrdinal (0 );
151
+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
152
+ assertThat (scorer .score (1 ), equalTo (expected ));
144
153
// cosine
145
154
expected = 0f ;
146
155
assertThat (luceneScore (COSINE , vec1 , vec2 , 1 , -5 , -5 ), equalTo (expected ));
147
156
supplier = factory .getInt7SQVectorScorerSupplier (COSINE , in , values , 1 ).get ();
148
- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
149
- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
157
+ scorer = supplier .scorer ();
158
+ scorer .setScoringOrdinal (0 );
159
+ assertThat (scorer .score (1 ), equalTo (expected ));
160
+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
150
161
// euclidean
151
162
expected = luceneScore (EUCLIDEAN , vec1 , vec2 , 1 , -5 , -5 );
152
163
supplier = factory .getInt7SQVectorScorerSupplier (EUCLIDEAN , in , values , 1 ).get ();
153
- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
154
- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
164
+ scorer = supplier .scorer ();
165
+ scorer .setScoringOrdinal (0 );
166
+ assertThat (scorer .score (1 ), equalTo (expected ));
167
+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
155
168
}
156
169
}
157
170
}
@@ -208,7 +221,9 @@ void testRandomSupplier(long maxChunkSize, Function<Integer, byte[]> byteArraySu
208
221
var values = vectorValues (dims , size , in , VectorSimilarityType .of (sim ));
209
222
float expected = luceneScore (sim , vectors [idx0 ], vectors [idx1 ], correction , offsets [idx0 ], offsets [idx1 ]);
210
223
var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , correction ).get ();
211
- assertThat (supplier .scorer (idx0 ).score (idx1 ), equalTo (expected ));
224
+ var scorer = supplier .scorer ();
225
+ scorer .setScoringOrdinal (idx0 );
226
+ assertThat (scorer .score (idx1 ), equalTo (expected ));
212
227
}
213
228
}
214
229
}
@@ -265,7 +280,7 @@ void testRandomScorerImpl(long maxChunkSize, Function<Integer, float[]> floatArr
265
280
266
281
var expected = luceneScore (sim , qVectors [idx0 ], qVectors [idx1 ], correction , corrections [idx0 ], corrections [idx1 ]);
267
282
var scorer = factory .getInt7SQVectorScorer (VectorSimilarityType .of (sim ), values , vectors [idx0 ]).get ();
268
- assertThat (scorer .score (idx1 ), equalTo ( expected ) );
283
+ assertEquals (scorer .score (idx1 ), expected , DELTA );
269
284
}
270
285
}
271
286
}
@@ -313,7 +328,9 @@ void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, Functi
313
328
var values = vectorValues (dims , size , in , VectorSimilarityType .of (sim ));
314
329
float expected = luceneScore (sim , vectors [idx0 ], vectors [idx1 ], correction , offsets [idx0 ], offsets [idx1 ]);
315
330
var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , correction ).get ();
316
- assertThat (supplier .scorer (idx0 ).score (idx1 ), equalTo (expected ));
331
+ var scorer = supplier .scorer ();
332
+ scorer .setScoringOrdinal (idx0 );
333
+ assertThat (scorer .score (idx1 ), equalTo (expected ));
317
334
}
318
335
}
319
336
}
@@ -352,7 +369,9 @@ public void testLarge() throws IOException {
352
369
var values = vectorValues (dims , size , in , VectorSimilarityType .of (sim ));
353
370
float expected = luceneScore (sim , vector (idx0 , dims ), vector (idx1 , dims ), correction , off0 , off1 );
354
371
var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , correction ).get ();
355
- assertThat (supplier .scorer (idx0 ).score (idx1 ), equalTo (expected ));
372
+ var scorer = supplier .scorer ();
373
+ scorer .setScoringOrdinal (idx0 );
374
+ assertThat (scorer .score (idx1 ), equalTo (expected ));
356
375
}
357
376
}
358
377
}
@@ -391,8 +410,8 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception {
391
410
var values = vectorValues (dims , 4 , in , VectorSimilarityType .of (sim ));
392
411
var scoreSupplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , 1f ).get ();
393
412
var tasks = List .<Callable <Optional <Throwable >>>of (
394
- new ScoreCallable (scoreSupplier .copy ().scorer (0 ) , 1 , expectedScore1 ),
395
- new ScoreCallable (scoreSupplier .copy ().scorer (2 ) , 3 , expectedScore2 )
413
+ new ScoreCallable (scoreSupplier .copy ().scorer (), 0 , 1 , expectedScore1 ),
414
+ new ScoreCallable (scoreSupplier .copy ().scorer (), 2 , 3 , expectedScore2 )
396
415
);
397
416
var executor = Executors .newFixedThreadPool (2 );
398
417
var results = executor .invokeAll (tasks );
@@ -408,14 +427,19 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception {
408
427
409
428
static class ScoreCallable implements Callable <Optional <Throwable >> {
410
429
411
- final RandomVectorScorer scorer ;
430
+ final UpdateableRandomVectorScorer scorer ;
412
431
final int ord ;
413
432
final float expectedScore ;
414
433
415
- ScoreCallable (RandomVectorScorer scorer , int ord , float expectedScore ) {
416
- this .scorer = scorer ;
417
- this .ord = ord ;
418
- this .expectedScore = expectedScore ;
434
+ ScoreCallable (UpdateableRandomVectorScorer scorer , int queryOrd , int ord , float expectedScore ) {
435
+ try {
436
+ this .scorer = scorer ;
437
+ this .scorer .setScoringOrdinal (queryOrd );
438
+ this .ord = ord ;
439
+ this .expectedScore = expectedScore ;
440
+ } catch (IOException e ) {
441
+ throw new RuntimeException (e );
442
+ }
419
443
}
420
444
421
445
@ Override
0 commit comments