Skip to content

Enable sort optimization on float and half_float #126342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
5 changes: 5 additions & 0 deletions docs/changelog/126342.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126342
summary: Enable sort optimization on float and `half_float`
area: Search
type: enhancement
issues: []

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested;
import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.MultiValueMode;
Expand Down Expand Up @@ -46,7 +47,7 @@ public enum NumericType {
LONG(false, SortField.Type.LONG, CoreValuesSourceType.NUMERIC),
DATE(false, SortField.Type.LONG, CoreValuesSourceType.DATE),
DATE_NANOSECONDS(false, SortField.Type.LONG, CoreValuesSourceType.DATE),
HALF_FLOAT(true, SortField.Type.LONG, CoreValuesSourceType.NUMERIC),
HALF_FLOAT(true, SortField.Type.FLOAT, CoreValuesSourceType.NUMERIC),
FLOAT(true, SortField.Type.FLOAT, CoreValuesSourceType.NUMERIC),
DOUBLE(true, SortField.Type.DOUBLE, CoreValuesSourceType.NUMERIC);

Expand Down Expand Up @@ -95,11 +96,13 @@ public final SortField sortField(
* 3. We Aren't using max or min to resolve the duplicates.
* 4. We have to cast the results to another type.
*/
if (sortRequiresCustomComparator()
|| nested != null
boolean requiresCustomComparator = nested != null
|| (sortMode != MultiValueMode.MAX && sortMode != MultiValueMode.MIN)
|| targetNumericType != getNumericType()) {
return new SortField(getFieldName(), source, reverse);
|| targetNumericType != getNumericType();
if (sortRequiresCustomComparator() || requiresCustomComparator) {
SortField sortField = new SortField(getFieldName(), source, reverse);
sortField.setOptimizeSortWithPoints(requiresCustomComparator == false && isIndexed());
return sortField;
}

SortedNumericSelector.Type selectorType = sortMode == MultiValueMode.MAX
Expand All @@ -108,20 +111,18 @@ public final SortField sortField(
SortField sortField = new SortedNumericSortField(getFieldName(), getNumericType().sortFieldType, reverse, selectorType);
sortField.setMissingValue(source.missingObject(missingValue, reverse));

// TODO: Now that numeric sort uses indexed points to skip over non-competitive documents,
// Lucene 9 requires that the same data/type is stored in points and doc values.
// We break this assumption in ES by using the wider numeric sort type for every field,
// (e.g. shorts use longs and floats use doubles). So for now we forbid the usage of
// points in numeric sort on field types that use a different sort type.
// We could expose these optimizations for all numeric types but that would require
// to rewrite the logic to handle types when merging results coming from different
// indices.
// TODO: enable sort optimization for BYTE, SHORT and INT types
// They can use custom comparator logic, similarly to HalfFloatValuesComparatorSource.
// The problem comes from the fact that we use SortField.Type.LONG for all these types.
// Investigate how to resolve this.
switch (getNumericType()) {
case DATE_NANOSECONDS:
case DATE:
case LONG:
case DOUBLE:
// longs, doubles and dates use the same type for doc-values and points.
case FLOAT:
// longs, doubles and dates use the same type for doc-values and points
// floats uses longs for doc-values, but Lucene's FloatComparator::getValueForDoc converts long value to float
sortField.setOptimizeSortWithPoints(isIndexed());
break;

Expand Down Expand Up @@ -199,7 +200,8 @@ private XFieldComparatorSource comparatorSource(
Nested nested
) {
return switch (targetNumericType) {
case HALF_FLOAT, FLOAT -> new FloatValuesComparatorSource(this, missingValue, sortMode, nested);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more a question than feedback. Since HALF_FLOAT is no longer handled by FloatValuesComparatorSource are there missing unit tests that are needed. I noticed FloatValuesComparatorSource is referenced by FloatNestedSortingTests. Does it make sense to have a HalfFloatNestedSortingTests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@john-wagster Thanks for the feedback, addressed in ee021d7

case FLOAT -> new FloatValuesComparatorSource(this, missingValue, sortMode, nested);
case HALF_FLOAT -> new HalfFloatValuesComparatorSource(this, missingValue, sortMode, nested);
case DOUBLE -> new DoubleValuesComparatorSource(this, missingValue, sortMode, nested);
case DATE -> dateComparatorSource(missingValue, sortMode, nested);
case DATE_NANOSECONDS -> dateNanosComparatorSource(missingValue, sortMode, nested);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
*/
public class FloatValuesComparatorSource extends IndexFieldData.XFieldComparatorSource {

private final IndexNumericFieldData indexFieldData;
final IndexNumericFieldData indexFieldData;

public FloatValuesComparatorSource(
IndexNumericFieldData indexFieldData,
Expand All @@ -54,7 +54,7 @@ public SortField.Type reducedType() {
return SortField.Type.FLOAT;
}

private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException {
NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException {
final SortedNumericDoubleValues values = indexFieldData.load(context).getDoubleValues();
if (nested == null) {
return FieldData.replaceMissing(sortMode.select(values), missingValue);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.sandbox.document.HalfFloatPoint;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.comparators.NumericComparator;
import org.apache.lucene.util.BitUtil;

import java.io.IOException;

/**
* Comparator for hal_float values.
* This comparator provides a skipping functionality – an iterator that can skip over non-competitive documents.
*/
public class HalfFloatComparator extends NumericComparator<Float> {
private final float[] values;
protected float topValue;
protected float bottom;

public HalfFloatComparator(int numHits, String field, Float missingValue, boolean reverse, Pruning pruning) {
super(field, missingValue != null ? missingValue : 0.0f, reverse, pruning, HalfFloatPoint.BYTES);
values = new float[numHits];
}

@Override
public int compare(int slot1, int slot2) {
return Float.compare(values[slot1], values[slot2]);
}

@Override
public void setTopValue(Float value) {
super.setTopValue(value);
topValue = value;
}

@Override
public Float value(int slot) {
return Float.valueOf(values[slot]);
}

@Override
protected long missingValueAsComparableLong() {
return HalfFloatPoint.halfFloatToSortableShort(missingValue);
}

@Override
protected long sortableBytesToLong(byte[] bytes) {
// Copied form HalfFloatPoint::sortableBytesToShort
short x = (short) BitUtil.VH_BE_SHORT.get(bytes, 0);
// Re-flip the sign bit to restore the original value:
return (short) (x ^ 0x8000);
}

@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new HalfFloatLeafComparator(context);
}

/** Leaf comparator for {@link HalfFloatComparator} that provides skipping functionality */
public class HalfFloatLeafComparator extends NumericLeafComparator {

public HalfFloatLeafComparator(LeafReaderContext context) throws IOException {
super(context);
}

private float getValueForDoc(int doc) throws IOException {
if (docValues.advanceExact(doc)) {
return Float.intBitsToFloat((int) docValues.longValue());
} else {
return missingValue;
}
}

@Override
public void setBottom(int slot) throws IOException {
bottom = values[slot];
super.setBottom(slot);
}

@Override
public int compareBottom(int doc) throws IOException {
return Float.compare(bottom, getValueForDoc(doc));
}

@Override
public int compareTop(int doc) throws IOException {
return Float.compare(topValue, getValueForDoc(doc));
}

@Override
public void copy(int slot, int doc) throws IOException {
values[slot] = getValueForDoc(doc);
super.copy(slot, doc);
}

@Override
protected long bottomAsComparableLong() {
return HalfFloatPoint.halfFloatToSortableShort(bottom);
}

@Override
protected long topAsComparableLong() {
return HalfFloatPoint.halfFloatToSortableShort(topValue);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.search.MultiValueMode;

import java.io.IOException;

/**
* Comparator source for half_float values.
*/
public class HalfFloatValuesComparatorSource extends FloatValuesComparatorSource {
public HalfFloatValuesComparatorSource(
IndexNumericFieldData indexFieldData,
@Nullable Object missingValue,
MultiValueMode sortMode,
Nested nested
) {
super(indexFieldData, missingValue, sortMode, nested);
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final float fMissingValue = (Float) missingObject(missingValue, reversed);
// NOTE: it's important to pass null as a missing value in the constructor so that
// the comparator doesn't check docsWithField since we replace missing values in select()
return new HalfFloatComparator(numHits, fieldname, null, reversed, enableSkipping) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new HalfFloatLeafComparator(context) {
@Override
protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException {
return HalfFloatValuesComparatorSource.this.getNumericDocValues(context, fMissingValue).getRawFloatValues();
}
};
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource;
import org.elasticsearch.search.MultiValueMode;
import org.elasticsearch.search.sort.ShardDocSortField;
Expand Down Expand Up @@ -627,7 +628,7 @@ private static Tuple<SortField, SortField> randomSortFieldCustomComparatorSource
IndexFieldData.XFieldComparatorSource comparatorSource;
boolean reverse = randomBoolean();
Object missingValue = null;
switch (randomIntBetween(0, 3)) {
switch (randomIntBetween(0, 4)) {
case 0 -> comparatorSource = new LongValuesComparatorSource(
null,
randomBoolean() ? randomLong() : null,
Expand All @@ -647,7 +648,13 @@ private static Tuple<SortField, SortField> randomSortFieldCustomComparatorSource
randomFrom(MultiValueMode.values()),
null
);
case 3 -> {
case 3 -> comparatorSource = new HalfFloatValuesComparatorSource(
null,
randomBoolean() ? randomFloat() : null,
randomFrom(MultiValueMode.values()),
null
);
case 4 -> {
comparatorSource = new BytesRefFieldComparatorSource(
null,
randomBoolean() ? "_first" : "_last",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ public <IFD extends IndexFieldData<?>> IFD getForField(String type, String field
null,
null
).docValues(docValues).build(context).fieldType();
} else if (type.equals("half_float")) {
fieldType = new NumberFieldMapper.Builder(
fieldName,
NumberFieldMapper.NumberType.HALF_FLOAT,
ScriptCompiler.NONE,
false,
true,
IndexVersion.current(),
null,
null
).docValues(docValues).build(context).fieldType();
} else if (type.equals("double")) {
fieldType = new NumberFieldMapper.Builder(
fieldName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,13 @@

import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.QueryBitSetProducer;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.search.join.ToParentBlockJoinQuery;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource;
import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.elasticsearch.search.MultiValueMode;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

public class FloatNestedSortingTests extends DoubleNestedSortingTests {

@Override
Expand All @@ -55,39 +39,4 @@ protected IndexFieldData.XFieldComparatorSource createFieldComparator(
protected IndexableField createField(String name, int value) {
return new SortedNumericDocValuesField(name, NumericUtils.floatToSortableInt(value));
}

protected void assertAvgScoreMode(
Query parentFilter,
IndexSearcher searcher,
IndexFieldData.XFieldComparatorSource innerFieldComparator
) throws IOException {
MultiValueMode sortMode = MultiValueMode.AVG;
Query childFilter = Queries.not(parentFilter);
XFieldComparatorSource nestedComparatorSource = createFieldComparator(
"field2",
sortMode,
-127,
createNested(searcher, parentFilter, childFilter)
);
Query query = new ToParentBlockJoinQuery(
new ConstantScoreQuery(childFilter),
new QueryBitSetProducer(parentFilter),
ScoreMode.None
);
Sort sort = new Sort(new SortField("field2", nestedComparatorSource));
TopDocs topDocs = searcher.search(query, 5, sort);
assertThat(topDocs.totalHits.value(), equalTo(7L));
assertThat(topDocs.scoreDocs.length, equalTo(5));
assertThat(topDocs.scoreDocs[0].doc, equalTo(11));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[0]).fields[0]).intValue(), equalTo(2));
assertThat(topDocs.scoreDocs[1].doc, equalTo(7));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[1]).fields[0]).intValue(), equalTo(2));
assertThat(topDocs.scoreDocs[2].doc, equalTo(3));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[2]).fields[0]).intValue(), equalTo(3));
assertThat(topDocs.scoreDocs[3].doc, equalTo(15));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[3]).fields[0]).intValue(), equalTo(3));
assertThat(topDocs.scoreDocs[4].doc, equalTo(19));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[4]).fields[0]).intValue(), equalTo(3));
}

}
Loading
Loading