diff --git a/docs/changelog/123074.yaml b/docs/changelog/123074.yaml
new file mode 100644
index 0000000000000..59ca1524893f8
--- /dev/null
+++ b/docs/changelog/123074.yaml
@@ -0,0 +1,5 @@
+pr: 123074
+summary: Adding ES|QL Reranker command in snapshot builds
+area: Ranking
+type: feature
+issues: []
diff --git a/muted-tests.yml b/muted-tests.yml
index c6274fc919768..70e4e3e96f139 100644
--- a/muted-tests.yml
+++ b/muted-tests.yml
@@ -369,6 +369,9 @@ tests:
- class: org.elasticsearch.snapshots.SharedClusterSnapshotRestoreIT
method: testDeletionOfFailingToRecoverIndexShouldStopRestore
issue: https://github.com/elastic/elasticsearch/issues/126204
+- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
+ method: testSimpleCircuitBreaking
+ issue: https://github.com/elastic/elasticsearch/issues/124337
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
method: testSchedulerCloseWaitsForRunningMerge
issue: https://github.com/elastic/elasticsearch/issues/125236
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
index 7a14185e7b5dc..64957328d48dd 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -64,6 +64,10 @@ public static class Request extends BaseInferenceActionRequest {
public static final ParseField TOP_N = new ParseField("top_n");
public static final ParseField TIMEOUT = new ParseField("timeout");
+ public static Builder builder(String inferenceEntityId, TaskType taskType) {
+ return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType);
+ }
+
static final ObjectParser PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
PARSER.declareStringArray(Request.Builder::setInput, INPUT);
diff --git a/x-pack/plugin/esql/gen/EsqlBaseLexer.java b/x-pack/plugin/esql/gen/EsqlBaseLexer.java
new file mode 100644
index 0000000000000..3c465373ece7a
--- /dev/null
+++ b/x-pack/plugin/esql/gen/EsqlBaseLexer.java
@@ -0,0 +1,153 @@
+// Generated from /Users/afoucret/git/elasticsearch/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4 by ANTLR 4.13.2
+
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import org.antlr.v4.runtime.Lexer;
+import org.antlr.v4.runtime.CharStream;
+import org.antlr.v4.runtime.Token;
+import org.antlr.v4.runtime.TokenStream;
+import org.antlr.v4.runtime.*;
+import org.antlr.v4.runtime.atn.*;
+import org.antlr.v4.runtime.dfa.DFA;
+import org.antlr.v4.runtime.misc.*;
+
+@SuppressWarnings({"all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape"})
+public class EsqlBaseLexer extends LexerConfig {
+ static { RuntimeMetaData.checkVersion("4.13.2", RuntimeMetaData.VERSION); }
+
+ protected static final DFA[] _decisionToDFA;
+ protected static final PredictionContextCache _sharedContextCache =
+ new PredictionContextCache();
+ public static final int
+ LINE_COMMENT=1, MULTILINE_COMMENT=2, WS=3;
+ public static String[] channelNames = {
+ "DEFAULT_TOKEN_CHANNEL", "HIDDEN"
+ };
+
+ public static String[] modeNames = {
+ "DEFAULT_MODE"
+ };
+
+ private static String[] makeRuleNames() {
+ return new String[] {
+ "LINE_COMMENT", "MULTILINE_COMMENT", "WS"
+ };
+ }
+ public static final String[] ruleNames = makeRuleNames();
+
+ private static String[] makeLiteralNames() {
+ return new String[] {
+ };
+ }
+ private static final String[] _LITERAL_NAMES = makeLiteralNames();
+ private static String[] makeSymbolicNames() {
+ return new String[] {
+ null, "LINE_COMMENT", "MULTILINE_COMMENT", "WS"
+ };
+ }
+ private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames();
+ public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES);
+
+ /**
+ * @deprecated Use {@link #VOCABULARY} instead.
+ */
+ @Deprecated
+ public static final String[] tokenNames;
+ static {
+ tokenNames = new String[_SYMBOLIC_NAMES.length];
+ for (int i = 0; i < tokenNames.length; i++) {
+ tokenNames[i] = VOCABULARY.getLiteralName(i);
+ if (tokenNames[i] == null) {
+ tokenNames[i] = VOCABULARY.getSymbolicName(i);
+ }
+
+ if (tokenNames[i] == null) {
+ tokenNames[i] = "";
+ }
+ }
+ }
+
+ @Override
+ @Deprecated
+ public String[] getTokenNames() {
+ return tokenNames;
+ }
+
+ @Override
+
+ public Vocabulary getVocabulary() {
+ return VOCABULARY;
+ }
+
+
+ public EsqlBaseLexer(CharStream input) {
+ super(input);
+ _interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
+ }
+
+ @Override
+ public String getGrammarFileName() { return "EsqlBaseLexer.g4"; }
+
+ @Override
+ public String[] getRuleNames() { return ruleNames; }
+
+ @Override
+ public String getSerializedATN() { return _serializedATN; }
+
+ @Override
+ public String[] getChannelNames() { return channelNames; }
+
+ @Override
+ public String[] getModeNames() { return modeNames; }
+
+ @Override
+ public ATN getATN() { return _ATN; }
+
+ public static final String _serializedATN =
+ "\u0004\u0000\u0003.\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002\u0001"+
+ "\u0007\u0001\u0002\u0002\u0007\u0002\u0001\u0000\u0001\u0000\u0001\u0000"+
+ "\u0001\u0000\u0005\u0000\f\b\u0000\n\u0000\f\u0000\u000f\t\u0000\u0001"+
+ "\u0000\u0003\u0000\u0012\b\u0000\u0001\u0000\u0003\u0000\u0015\b\u0000"+
+ "\u0001\u0000\u0001\u0000\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"+
+ "\u0001\u0001\u0005\u0001\u001e\b\u0001\n\u0001\f\u0001!\t\u0001\u0001"+
+ "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0002\u0004"+
+ "\u0002)\b\u0002\u000b\u0002\f\u0002*\u0001\u0002\u0001\u0002\u0001\u001f"+
+ "\u0000\u0003\u0001\u0001\u0003\u0002\u0005\u0003\u0001\u0000\u0002\u0002"+
+ "\u0000\n\n\r\r\u0003\u0000\t\n\r\r 3\u0000\u0001\u0001\u0000\u0000\u0000"+
+ "\u0000\u0003\u0001\u0000\u0000\u0000\u0000\u0005\u0001\u0000\u0000\u0000"+
+ "\u0001\u0007\u0001\u0000\u0000\u0000\u0003\u0018\u0001\u0000\u0000\u0000"+
+ "\u0005(\u0001\u0000\u0000\u0000\u0007\b\u0005/\u0000\u0000\b\t\u0005/"+
+ "\u0000\u0000\t\r\u0001\u0000\u0000\u0000\n\f\b\u0000\u0000\u0000\u000b"+
+ "\n\u0001\u0000\u0000\u0000\f\u000f\u0001\u0000\u0000\u0000\r\u000b\u0001"+
+ "\u0000\u0000\u0000\r\u000e\u0001\u0000\u0000\u0000\u000e\u0011\u0001\u0000"+
+ "\u0000\u0000\u000f\r\u0001\u0000\u0000\u0000\u0010\u0012\u0005\r\u0000"+
+ "\u0000\u0011\u0010\u0001\u0000\u0000\u0000\u0011\u0012\u0001\u0000\u0000"+
+ "\u0000\u0012\u0014\u0001\u0000\u0000\u0000\u0013\u0015\u0005\n\u0000\u0000"+
+ "\u0014\u0013\u0001\u0000\u0000\u0000\u0014\u0015\u0001\u0000\u0000\u0000"+
+ "\u0015\u0016\u0001\u0000\u0000\u0000\u0016\u0017\u0006\u0000\u0000\u0000"+
+ "\u0017\u0002\u0001\u0000\u0000\u0000\u0018\u0019\u0005/\u0000\u0000\u0019"+
+ "\u001a\u0005*\u0000\u0000\u001a\u001f\u0001\u0000\u0000\u0000\u001b\u001e"+
+ "\u0003\u0003\u0001\u0000\u001c\u001e\t\u0000\u0000\u0000\u001d\u001b\u0001"+
+ "\u0000\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e!\u0001\u0000"+
+ "\u0000\u0000\u001f \u0001\u0000\u0000\u0000\u001f\u001d\u0001\u0000\u0000"+
+ "\u0000 \"\u0001\u0000\u0000\u0000!\u001f\u0001\u0000\u0000\u0000\"#\u0005"+
+ "*\u0000\u0000#$\u0005/\u0000\u0000$%\u0001\u0000\u0000\u0000%&\u0006\u0001"+
+ "\u0000\u0000&\u0004\u0001\u0000\u0000\u0000\')\u0007\u0001\u0000\u0000"+
+ "(\'\u0001\u0000\u0000\u0000)*\u0001\u0000\u0000\u0000*(\u0001\u0000\u0000"+
+ "\u0000*+\u0001\u0000\u0000\u0000+,\u0001\u0000\u0000\u0000,-\u0006\u0002"+
+ "\u0000\u0000-\u0006\u0001\u0000\u0000\u0000\u0007\u0000\r\u0011\u0014"+
+ "\u001d\u001f*\u0001\u0000\u0001\u0000";
+ public static final ATN _ATN =
+ new ATNDeserializer().deserialize(_serializedATN.toCharArray());
+ static {
+ _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()];
+ for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) {
+ _decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i);
+ }
+ }
+}
\ No newline at end of file
diff --git a/x-pack/plugin/esql/gen/EsqlBaseLexer.tokens b/x-pack/plugin/esql/gen/EsqlBaseLexer.tokens
new file mode 100644
index 0000000000000..907d464c07573
--- /dev/null
+++ b/x-pack/plugin/esql/gen/EsqlBaseLexer.tokens
@@ -0,0 +1,3 @@
+LINE_COMMENT=1
+MULTILINE_COMMENT=2
+WS=3
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
index c774fb5955133..37b818d222548 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
@@ -52,6 +52,7 @@
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V12;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_PLANNING_V1;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METADATA_FIELDS_REMOTE_TEST;
+import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.UNMAPPED_FIELDS;
import static org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase.Mode.SYNC;
import static org.mockito.ArgumentMatchers.any;
@@ -130,6 +131,8 @@ protected void shouldSkipTest(String testName) throws IOException {
assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V12.capabilityName()));
// Unmapped fields require a correct capability response from every cluster, which isn't currently implemented.
assumeFalse("UNMAPPED FIELDS not yet supported in CCS", testCase.requiredCapabilities.contains(UNMAPPED_FIELDS.capabilityName()));
+ // Need to do additional development to get CCS support for the RERANK command
+ assumeFalse("RERANK not yet supported in CCS", testCase.requiredCapabilities.contains(RERANK.capabilityName()));
}
@Override
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
index fb06d31ddafac..afd7d95196ab8 100644
--- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
@@ -66,8 +66,11 @@
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND;
@@ -134,6 +137,10 @@ public void setup() throws IOException {
createInferenceEndpoint(client());
}
+ if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) {
+ createRerankInferenceEndpoint(client());
+ }
+
boolean supportsLookup = supportsIndexModeLookup();
boolean supportsSourceMapping = supportsSourceFieldMapping();
if (indexExists(availableDatasetsForEs(client(), supportsLookup, supportsSourceMapping).iterator().next().indexName()) == false) {
@@ -153,6 +160,7 @@ public static void wipeTestData() throws IOException {
}
deleteInferenceEndpoint(client());
+ deleteRerankInferenceEndpoint(client());
}
public boolean logResults() {
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestRerankTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestRerankTestCase.java
new file mode 100644
index 0000000000000..ce5e58d61fbb3
--- /dev/null
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestRerankTestCase.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.qa.rest;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
+import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
+import static org.hamcrest.core.StringContains.containsString;
+
+public class RestRerankTestCase extends ESRestTestCase {
+
+ @Before
+ public void skipWhenRerankDisabled() throws IOException {
+ assumeTrue(
+ "Requires RERANK capability",
+ EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.RERANK.capabilityName()))
+ );
+ }
+
+ @Before
+ @After
+ public void assertRequestBreakerEmpty() throws Exception {
+ EsqlSpecTestCase.assertRequestBreakerEmpty();
+ }
+
+ @Before
+ public void setUpInferenceEndpoint() throws IOException {
+ createRerankInferenceEndpoint(adminClient());
+ }
+
+ @Before
+ public void setUpTestIndex() throws IOException {
+ Request request = new Request("PUT", "/rerank-test-index");
+ request.setJsonEntity("""
+ {
+ "mappings": {
+ "properties": {
+ "title": { "type": "text" },
+ "author": { "type": "text" }
+ }
+ }
+ }""");
+ assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
+
+ request = new Request("POST", "/rerank-test-index/_bulk");
+ request.addParameter("refresh", "true");
+ request.setJsonEntity("""
+ { "index": {"_id": 1} }
+ { "title": "The Future of Exploration", "author": "John Doe" }
+ { "index": {"_id": 2} }
+ { "title": "Deep Sea Exploration", "author": "Jane Smith" }
+ { "index": {"_id": 3} }
+ { "title": "History of Space Exploration", "author": "Alice Johnson" }
+ """);
+ assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
+ }
+
+ @After
+ public void wipeData() throws IOException {
+ try {
+ adminClient().performRequest(new Request("DELETE", "/rerank-test-index"));
+ } catch (ResponseException e) {
+ // 404 here just means we had no indexes
+ if (e.getResponse().getStatusLine().getStatusCode() != 404) {
+ throw e;
+ }
+ }
+
+ deleteRerankInferenceEndpoint(adminClient());
+ }
+
+ public void testRerankWithSingleField() throws IOException {
+ String query = """
+ FROM rerank-test-index
+ | WHERE match(title, "exploration")
+ | RERANK "exploration" ON title WITH test_reranker
+ | EVAL _score = ROUND(_score, 5)
+ """;
+
+ Map<String, Object> result = runEsqlQuery(query);
+
+ var expectedValues = List.of(
+ List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
+ List.of("John Doe", "The Future of Exploration", 0.02632d),
+ List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
+ );
+
+ assertResultMap(result, defaultOutputColumns(), expectedValues);
+ }
+
+ public void testRerankWithMultipleFields() throws IOException {
+ String query = """
+ FROM rerank-test-index
+ | WHERE match(title, "exploration")
+ | RERANK "exploration" ON title, author WITH test_reranker
+ | EVAL _score = ROUND(_score, 5)
+ """;
+
+ Map<String, Object> result = runEsqlQuery(query);
+
+ var expectedValues = List.of(
+ List.of("Jane Smith", "Deep Sea Exploration", 0.01818d),
+ List.of("John Doe", "The Future of Exploration", 0.01754d),
+ List.of("Alice Johnson", "History of Space Exploration", 0.01515d)
+ );
+
+ assertResultMap(result, defaultOutputColumns(), expectedValues);
+ }
+
+ public void testRerankWithPositionalParams() throws IOException {
+ String query = """
+ FROM rerank-test-index
+ | WHERE match(title, "exploration")
+ | RERANK ? ON title WITH ?
+ | EVAL _score = ROUND(_score, 5)
+ """;
+
+ Map<String, Object> result = runEsqlQuery(query, "[\"exploration\", \"test_reranker\"]");
+
+ var expectedValues = List.of(
+ List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
+ List.of("John Doe", "The Future of Exploration", 0.02632d),
+ List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
+ );
+
+ assertResultMap(result, defaultOutputColumns(), expectedValues);
+ }
+
+ public void testRerankWithNamedParams() throws IOException {
+ String query = """
+ FROM rerank-test-index
+ | WHERE match(title, ?queryText)
+ | RERANK ?queryText ON title WITH ?inferenceId
+ | EVAL _score = ROUND(_score, 5)
+ """;
+
+ Map<String, Object> result = runEsqlQuery(query, "[{\"queryText\": \"exploration\"}, {\"inferenceId\": \"test_reranker\"}]");
+
+ var expectedValues = List.of(
+ List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
+ List.of("John Doe", "The Future of Exploration", 0.02632d),
+ List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
+ );
+
+ assertResultMap(result, defaultOutputColumns(), expectedValues);
+ }
+
+ public void testRerankWithMissingInferenceId() {
+ String query = """
+ FROM rerank-test-index
+ | WHERE match(title, "exploration")
+ | RERANK "exploration" ON title WITH test_missing
+ | EVAL _score = ROUND(_score, 5)
+ """;
+
+ ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query));
+ assertThat(re.getMessage(), containsString("Inference endpoint not found"));
+ }
+
+ private static List
*/
@Override public T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
index 0c35df036215a..f7812d81a28f1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
@@ -635,6 +635,16 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
+ * @param ctx the parse tree
+ */
+ void enterRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
+ * @param ctx the parse tree
+ */
+ void exitRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
/**
* Enter a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
index c21618dbcbf53..dcae327203aca 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
@@ -388,6 +388,12 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
* @return the visitor result
*/
T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
/**
* Visit a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
index c54073004c365..d6eddf1883489 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
@@ -68,6 +68,7 @@
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@@ -707,4 +708,56 @@ public PlanFactory visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) {
return new OrderBy(source, dedup, order);
};
}
+
+ @Override
+ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
+ var source = source(ctx);
+
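+ // RERANK is gated behind a snapshot-only capability while the command is in preview.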
+ if (false == EsqlCapabilities.Cap.RERANK.isEnabled()) {
+ throw new ParsingException(source, "RERANK is in preview and only available in SNAPSHOT build");
+ }
+
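+ // The query text must be a non-null string literal.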
+ Expression queryText = expression(ctx.queryText);
+ if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
+ if (queryTextLiteral.value() == null) {
+ throw new ParsingException(
+ source(ctx.queryText),
+ "Query text cannot be null or undefined in RERANK",
+ ctx.queryText.getText()
+ );
+ }
+ } else {
+ throw new ParsingException(
+ source(ctx.queryText),
+ "RERANK only support string as query text but [{}] cannot be used as string",
+ ctx.queryText.getText()
+ );
+ }
+
+ return p -> new Rerank(source, p, inferenceId(ctx.inferenceId), queryText, visitFields(ctx.fields()));
+ }
+
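+ // The inference endpoint id can be given either as a plain identifier or as a (named or positional) query parameter.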
+ public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
+ if (ctx.identifier() != null) {
+ return new Literal(source(ctx), visitIdentifier(ctx.identifier()), KEYWORD);
+ }
+
+ if (expression(ctx.parameter()) instanceof Literal literalParam) {
+ if (literalParam.value() != null) {
+ return literalParam;
+ }
+
+ throw new ParsingException(
+ source(ctx.parameter()),
+ "Query parameter [{}] is null or undefined and cannot be used as inference id",
+ ctx.parameter().getText()
+ );
+ }
+
+ throw new ParsingException(
+ source(ctx.parameter()),
+ "Query parameter [{}] is not a string and cannot be used as inference id",
+ ctx.parameter().getText()
+ );
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
index 2b898c7cdbed8..292c7044be03a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
@@ -23,6 +23,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
@@ -49,6 +50,7 @@
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import java.util.ArrayList;
import java.util.List;
@@ -81,6 +83,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
MvExpand.ENTRY,
OrderBy.ENTRY,
Project.ENTRY,
+ Rerank.ENTRY,
TimeSeriesAggregate.ENTRY,
TopN.ENTRY
);
@@ -106,6 +109,7 @@ public static List<NamedWriteableRegistry.Entry> physical() {
LocalSourceExec.ENTRY,
MvExpandExec.ENTRY,
ProjectExec.ENTRY,
+ RerankExec.ENTRY,
ShowExec.ENTRY,
SubqueryExec.ENTRY,
TimeSeriesAggregateExec.ENTRY,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
new file mode 100644
index 0000000000000..9f12f258dd8ed
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.logical.inference;
+
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+
+import java.io.IOException;
+import java.util.Objects;
+
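+/**
+ * Base class for logical plan nodes that call an inference endpoint (e.g. RERANK).
+ */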
+public abstract class InferencePlan extends UnaryPlan {
+
+ private final Expression inferenceId;
+
+ protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) {
+ super(source, child);
+ this.inferenceId = inferenceId;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ Source.EMPTY.writeTo(out);
+ out.writeNamedWriteable(child());
+ out.writeNamedWriteable(inferenceId());
+ }
+
+ public Expression inferenceId() {
+ return inferenceId;
+ }
+
+ @Override
+ public boolean expressionsResolved() {
+ return inferenceId.resolved();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ InferencePlan other = (InferencePlan) o;
+ return Objects.equals(inferenceId(), other.inferenceId());
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), inferenceId());
+ }
+
+ public abstract TaskType taskType();
+
+ public abstract InferencePlan withInferenceId(Expression newInferenceId);
+
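+ // Replaces the inference id with an unresolved attribute carrying the resolution error, so analysis can report it.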
+ public InferencePlan withInferenceResolutionError(String inferenceId, String error) {
+ return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
new file mode 100644
index 0000000000000..c9cec8e083833
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java
@@ -0,0 +1,190 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.logical.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.core.capabilities.Resolvables;
+import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
+import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.QueryPlan;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
+import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
+import org.elasticsearch.xpack.esql.plan.logical.SurrogateLogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
+import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
+
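+/**
+ * Logical plan node for the RERANK command, e.g. {@code RERANK "query text" ON field1, field2 WITH inference_id}.
+ */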
+public class Rerank extends InferencePlan implements SortAgnostic, SurrogateLogicalPlan {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
+ private final Attribute scoreAttribute;
+ private final Expression queryText;
+ private final List<Alias> rerankFields;
+ private List<Attribute> lazyOutput;
+
+ public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
+ super(source, child, inferenceId);
+ this.queryText = queryText;
+ this.rerankFields = rerankFields;
+ this.scoreAttribute = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
+ }
+
+ public Rerank(
+ Source source,
+ LogicalPlan child,
+ Expression inferenceId,
+ Expression queryText,
+ List<Alias> rerankFields,
+ Attribute scoreAttribute
+ ) {
+ super(source, child, inferenceId);
+ this.queryText = queryText;
+ this.rerankFields = rerankFields;
+ this.scoreAttribute = scoreAttribute;
+ }
+
+ public Rerank(StreamInput in) throws IOException {
+ this(
+ Source.readFrom((PlanStreamInput) in),
+ in.readNamedWriteable(LogicalPlan.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readCollectionAsList(Alias::new),
+ in.readNamedWriteable(Attribute.class)
+ );
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeNamedWriteable(queryText);
+ out.writeCollection(rerankFields());
+ out.writeNamedWriteable(scoreAttribute);
+ }
+
+ public Expression queryText() {
+ return queryText;
+ }
+
+ public List<Alias> rerankFields() {
+ return rerankFields;
+ }
+
+ public Attribute scoreAttribute() {
+ return scoreAttribute;
+ }
+
+ @Override
+ public TaskType taskType() {
+ return TaskType.RERANK;
+ }
+
+ @Override
+ public Rerank withInferenceId(Expression newInferenceId) {
+ return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
+ }
+
+ public Rerank withRerankFields(List<Alias> newRerankFields) {
+ return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute);
+ }
+
+ public Rerank withScoreAttribute(Attribute newScoreAttribute) {
+ return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute);
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ @Override
+ public UnaryPlan replaceChild(LogicalPlan newChild) {
+ return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
+ }
+
+ @Override
+ protected AttributeSet computeReferences() {
+ AttributeSet.Builder refs = computeReferences(rerankFields).asBuilder();
+
+ if (planHasAttribute(child(), scoreAttribute)) {
+ refs.add(scoreAttribute);
+ }
+
+ return refs.build();
+ }
+
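+ // Collects the attributes referenced by the rerank fields, excluding the aliases the fields themselves introduce.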
+ public static AttributeSet computeReferences(List<Alias> fields) {
+ AttributeSet rerankFields = AttributeSet.of(asAttributes(fields));
+ return Expressions.references(fields).subtract(rerankFields);
+ }
+
+ @Override
+ public boolean expressionsResolved() {
+ return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields) && scoreAttribute.resolved();
+ }
+
+ @Override
+ protected NodeInfo<? extends LogicalPlan> info() {
+ return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ Rerank rerank = (Rerank) o;
+ return Objects.equals(queryText, rerank.queryText)
+ && Objects.equals(rerankFields, rerank.rerankFields)
+ && Objects.equals(scoreAttribute, rerank.scoreAttribute);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
+ }
+
+ @Override
+ public LogicalPlan surrogate() {
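+ // RERANK is followed by an implicit sort on the score attribute, descending.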
+ Order sortOrder = new Order(source(), scoreAttribute, Order.OrderDirection.DESC, Order.NullsPosition.ANY);
+ return new OrderBy(source(), this, List.of(sortOrder));
+ }
+
+ @Override
+ public List<Attribute> output() {
+ if (lazyOutput == null) {
+ lazyOutput = planHasAttribute(child(), scoreAttribute)
+ ? child().output()
+ : mergeOutputAttributes(List.of(scoreAttribute), child().output());
+ }
+
+ return lazyOutput;
+ }
+
+ public static boolean planHasAttribute(QueryPlan<?> plan, Attribute attribute) {
+ return plan.outputSet().stream().anyMatch(attr -> attr.equals(attribute));
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
new file mode 100644
index 0000000000000..7954690a0fdc0
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/InferenceExec.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical.inference;
+
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public abstract class InferenceExec extends UnaryExec {
+ private final Expression inferenceId;
+
+ protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
+ super(source, child);
+ this.inferenceId = inferenceId;
+ }
+
+ public Expression inferenceId() {
+ return inferenceId;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ Source.EMPTY.writeTo(out);
+ out.writeNamedWriteable(child());
+ out.writeNamedWriteable(inferenceId());
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ InferenceExec that = (InferenceExec) o;
+ return inferenceId.equals(that.inferenceId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), inferenceId());
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
new file mode 100644
index 0000000000000..4570775af2ed1
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExec.java
@@ -0,0 +1,138 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
+import static org.elasticsearch.xpack.esql.plan.logical.inference.Rerank.planHasAttribute;
+
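+/**
+ * Physical plan node for the RERANK command; the actual scoring is done by the rerank operator built in LocalExecutionPlanner.
+ */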
+public class RerankExec extends InferenceExec {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+ PhysicalPlan.class,
+ "RerankExec",
+ RerankExec::new
+ );
+
+ private final Expression queryText;
+ private final List<Alias> rerankFields;
+ private final Attribute scoreAttribute;
+
+ public RerankExec(
+ Source source,
+ PhysicalPlan child,
+ Expression inferenceId,
+ Expression queryText,
+ List<Alias> rerankFields,
+ Attribute scoreAttribute
+ ) {
+ super(source, child, inferenceId);
+ this.queryText = queryText;
+ this.rerankFields = rerankFields;
+ this.scoreAttribute = scoreAttribute;
+ }
+
+ public RerankExec(StreamInput in) throws IOException {
+ this(
+ Source.readFrom((PlanStreamInput) in),
+ in.readNamedWriteable(PhysicalPlan.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readCollectionAsList(Alias::new),
+ in.readNamedWriteable(Attribute.class)
+ );
+ }
+
+ public Expression queryText() {
+ return queryText;
+ }
+
+ public List<Alias> rerankFields() {
+ return rerankFields;
+ }
+
+ public Attribute scoreAttribute() {
+ return scoreAttribute;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeNamedWriteable(queryText());
+ out.writeCollection(rerankFields());
+ out.writeNamedWriteable(scoreAttribute);
+ }
+
+ @Override
+ protected NodeInfo<? extends PhysicalPlan> info() {
+ return NodeInfo.create(this, RerankExec::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
+ }
+
+ @Override
+ public UnaryExec replaceChild(PhysicalPlan newChild) {
+ return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
+ }
+
+ @Override
+ public List<Attribute> output() {
+ if (planHasAttribute(child(), scoreAttribute)) {
+ return child().output();
+ }
+
+ return mergeOutputAttributes(List.of(scoreAttribute), child().output());
+ }
+
+ @Override
+ protected AttributeSet computeReferences() {
+ AttributeSet.Builder refs = Rerank.computeReferences(rerankFields).asBuilder();
+
+ if (planHasAttribute(child(), scoreAttribute)) {
+ refs.add(scoreAttribute);
+ }
+
+ return refs.build();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ RerankExec rerank = (RerankExec) o;
+ return Objects.equals(queryText, rerank.queryText)
+ && Objects.equals(rerankFields, rerank.rerankFields)
+ && Objects.equals(scoreAttribute, rerank.scoreAttribute);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index ada28c2790c39..1e0c51b8fcab3 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.planner;
import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.Describable;
@@ -23,6 +24,7 @@
import org.elasticsearch.compute.operator.ColumnLoadOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
import org.elasticsearch.compute.operator.LimitOperator;
@@ -57,6 +59,7 @@
import org.elasticsearch.node.Node;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -78,6 +81,9 @@
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter;
import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.RerankOperator;
+import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
import org.elasticsearch.xpack.esql.plan.logical.Fork;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.ChangePointExec;
@@ -104,6 +110,7 @@
import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.score.ScoreMapper;
@@ -112,6 +119,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -144,6 +152,7 @@ public class LocalExecutionPlanner {
private final Supplier exchangeSinkSupplier;
private final EnrichLookupService enrichLookupService;
private final LookupFromIndexService lookupFromIndexService;
+ private final InferenceRunner inferenceRunner;
private final PhysicalOperationProviders physicalOperationProviders;
private final List shardContexts;
@@ -159,6 +168,7 @@ public LocalExecutionPlanner(
Supplier exchangeSinkSupplier,
EnrichLookupService enrichLookupService,
LookupFromIndexService lookupFromIndexService,
+ InferenceRunner inferenceRunner,
PhysicalOperationProviders physicalOperationProviders,
List shardContexts
) {
@@ -174,6 +184,7 @@ public LocalExecutionPlanner(
this.exchangeSinkSupplier = exchangeSinkSupplier;
this.enrichLookupService = enrichLookupService;
this.lookupFromIndexService = lookupFromIndexService;
+ this.inferenceRunner = inferenceRunner;
this.physicalOperationProviders = physicalOperationProviders;
this.shardContexts = shardContexts;
}
@@ -242,6 +253,8 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c
return planLimit(limit, context);
} else if (node instanceof MvExpandExec mvExpand) {
return planMvExpand(mvExpand, context);
+ } else if (node instanceof RerankExec rerank) {
+ return planRerank(rerank, context);
} else if (node instanceof ChangePointExec changePoint) {
return planChangePoint(changePoint, context);
}
@@ -543,6 +556,36 @@ private PhysicalOperation planEnrich(EnrichExec enrich, LocalExecutionPlannerCon
);
}
+ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerContext context) {
+ PhysicalOperation source = plan(rerank.child(), context);
+
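+ // Build one evaluator per rerank field; the XContent row encoder combines them into a single YAML document per row.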
+ Map<ColumnInfoImpl, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers = new LinkedHashMap<>();
+
+ for (var rerankField : rerank.rerankFields()) {
+ rerankFieldsEvaluatorSuppliers.put(
+ new ColumnInfoImpl(rerankField.name(), rerankField.dataType(), null),
+ EvalMapper.toEvaluator(context.foldCtx(), rerankField.child(), source.layout)
+ );
+ }
+
+ XContentRowEncoder.Factory rowEncoderFactory = XContentRowEncoder.yamlRowEncoderFactory(rerankFieldsEvaluatorSuppliers);
+
+ String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx()));
+ String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx()));
+
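+ // Append the score attribute to the layout when the child does not already output it, then resolve its channel.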
+ Layout outputLayout = source.layout;
+ if (source.layout.get(rerank.scoreAttribute().id()) == null) {
+ outputLayout = source.layout.builder().append(rerank.scoreAttribute()).build();
+ }
+
+ int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
+
+ return source.with(
+ new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel),
+ outputLayout
+ );
+ }
+
private PhysicalOperation planHashJoin(HashJoinExec join, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(join.left(), context);
int positionsChannel = source.layout.numberOfChannels();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java
index 5c727dca8329e..af7db0f190c03 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java
@@ -23,6 +23,7 @@
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
@@ -38,6 +39,7 @@
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import java.util.ArrayList;
import java.util.List;
@@ -173,6 +175,18 @@ private PhysicalPlan mapUnary(UnaryPlan unary) {
return new TopNExec(topN.source(), mappedChild, topN.order(), topN.limit(), null);
}
+ if (unary instanceof Rerank rerank) {
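+ // Close any data-node fragment below with an exchange so the rerank runs on top of it.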
+ mappedChild = addExchangeForFragment(rerank, mappedChild);
+ return new RerankExec(
+ rerank.source(),
+ mappedChild,
+ rerank.inferenceId(),
+ rerank.queryText(),
+ rerank.rerankFields(),
+ rerank.scoreAttribute()
+ );
+ }
+
//
// Pipeline operators
//
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
index 473ca4f92c8a8..6f44634f40ebb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@@ -42,6 +43,7 @@
import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
import java.util.List;
@@ -86,6 +88,17 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
return new GrokExec(grok.source(), child, grok.input(), grok.parser(), grok.extractedFields());
}
+ if (p instanceof Rerank rerank) {
+ return new RerankExec(
+ rerank.source(),
+ child,
+ rerank.inferenceId(),
+ rerank.queryText(),
+ rerank.rerankFields(),
+ rerank.scoreAttribute()
+ );
+ }
+
if (p instanceof Enrich enrich) {
return new EnrichExec(
enrich.source(),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
index a9f00f3635703..5724587f0573b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
@@ -49,6 +49,7 @@
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
@@ -125,6 +126,7 @@ public class ComputeService {
private final DriverTaskRunner driverRunner;
private final EnrichLookupService enrichLookupService;
private final LookupFromIndexService lookupFromIndexService;
+ private final InferenceRunner inferenceRunner;
private final ClusterService clusterService;
private final AtomicLong childSessionIdGenerator = new AtomicLong();
private final DataNodeComputeHandler dataNodeComputeHandler;
@@ -133,25 +135,24 @@ public class ComputeService {
@SuppressWarnings("this-escape")
public ComputeService(
- SearchService searchService,
- TransportService transportService,
- ExchangeService exchangeService,
+ TransportActionServices transportActionServices,
EnrichLookupService enrichLookupService,
LookupFromIndexService lookupFromIndexService,
- ClusterService clusterService,
ThreadPool threadPool,
BigArrays bigArrays,
BlockFactory blockFactory
) {
- this.searchService = searchService;
- this.transportService = transportService;
+ this.searchService = transportActionServices.searchService();
+ this.transportService = transportActionServices.transportService();
+ this.exchangeService = transportActionServices.exchangeService();
this.bigArrays = bigArrays.withCircuitBreaking();
this.blockFactory = blockFactory;
var esqlExecutor = threadPool.executor(ThreadPool.Names.SEARCH);
this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor);
this.enrichLookupService = enrichLookupService;
this.lookupFromIndexService = lookupFromIndexService;
- this.clusterService = clusterService;
+ this.inferenceRunner = transportActionServices.inferenceRunner();
+ this.clusterService = transportActionServices.clusterService();
this.dataNodeComputeHandler = new DataNodeComputeHandler(this, searchService, transportService, exchangeService, esqlExecutor);
this.clusterComputeHandler = new ClusterComputeHandler(
this,
@@ -160,7 +161,6 @@ public ComputeService(
esqlExecutor,
dataNodeComputeHandler
);
- this.exchangeService = exchangeService;
}
public void execute(
@@ -428,6 +428,7 @@ public SourceProvider createSourceProvider() {
context.exchangeSinkSupplier(),
enrichLookupService,
lookupFromIndexService,
+ inferenceRunner,
new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis()),
contexts
);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java
index ad112542e000a..0874ff4068227 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java
@@ -13,6 +13,7 @@
import org.elasticsearch.search.SearchService;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.usage.UsageService;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
public record TransportActionServices(
TransportService transportService,
@@ -20,5 +21,6 @@ public record TransportActionServices(
ExchangeService exchangeService,
ClusterService clusterService,
IndexNameExpressionResolver indexNameExpressionResolver,
- UsageService usageService
+ UsageService usageService,
+ InferenceRunner inferenceRunner
) {}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
index 72ca465f647b7..2b3a877b48205 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
@@ -49,6 +49,7 @@
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner;
import org.elasticsearch.xpack.esql.session.Result;
@@ -126,17 +127,7 @@ public TransportEsqlQueryAction(
bigArrays,
blockFactoryProvider.blockFactory()
);
- this.computeService = new ComputeService(
- searchService,
- transportService,
- exchangeService,
- enrichLookupService,
- lookupFromIndexService,
- clusterService,
- threadPool,
- bigArrays,
- blockFactoryProvider.blockFactory()
- );
+
this.asyncTaskManagementService = new AsyncTaskManagementService<>(
XPackPlugin.ASYNC_RESULTS_INDEX,
client,
@@ -159,8 +150,19 @@ public TransportEsqlQueryAction(
exchangeService,
clusterService,
indexNameExpressionResolver,
- usageService
+ usageService,
+ new InferenceRunner(client)
);
+
+ this.computeService = new ComputeService(
+ services,
+ enrichLookupService,
+ lookupFromIndexService,
+ threadPool,
+ bigArrays,
+ blockFactoryProvider.blockFactory()
+ );
+
defaultAllowPartialResults = EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS.get(clusterService.getSettings());
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS, v -> defaultAllowPartialResults = v);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
index def4e23af381f..b218435b03c36 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
@@ -50,6 +50,8 @@
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.index.MappingException;
+import org.elasticsearch.xpack.esql.inference.InferenceResolution;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
@@ -63,6 +65,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
+import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@@ -119,6 +122,7 @@ public interface PlanRunner {
private final PlanTelemetry planTelemetry;
private final IndicesExpressionGrouper indicesExpressionGrouper;
    private Set<String> configuredClusters;
+ private final InferenceRunner inferenceRunner;
public EsqlSession(
String sessionId,
@@ -146,6 +150,7 @@ public EsqlSession(
this.physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
this.planTelemetry = planTelemetry;
this.indicesExpressionGrouper = indicesExpressionGrouper;
+ this.inferenceRunner = services.inferenceRunner();
this.preMapper = new PreMapper(services);
}
@@ -335,7 +340,7 @@ public void analyzedPlan(
        Function<PreAnalysisResult, LogicalPlan> analyzeAction = (l) -> {
Analyzer analyzer = new Analyzer(
- new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution),
+ new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution, l.inferenceResolution),
verifier
);
LogicalPlan plan = analyzer.analyze(parsed);
@@ -367,7 +372,9 @@ public void analyzedPlan(
        var listener = SubscribableListener.<EnrichResolution>newForked(
l -> enrichPolicyResolver.resolvePolicies(targetClusters, unresolvedPolicies, l)
- ).andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l));
+ )
+ .andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l))
+ .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l));
// first resolve the lookup indices, then the main indices
for (var index : preAnalysis.lookupIndices) {
listener = listener.andThen((l, preAnalysisResult) -> { preAnalyzeLookupIndex(index, preAnalysisResult, l); });
@@ -580,6 +587,14 @@ private static void resolveFieldNames(LogicalPlan parsed, EnrichResolution enric
}
}
+ private void resolveInferences(
+ List<InferencePlan> inferencePlans,
+ PreAnalysisResult preAnalysisResult,
+ ActionListener<PreAnalysisResult> l
+ ) {
+ inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution));
+ }
+
    static PreAnalysisResult fieldNames(LogicalPlan parsed, Set<String> enrichPolicyMatchFields, PreAnalysisResult result) {
if (false == parsed.anyMatch(plan -> plan instanceof Aggregate || plan instanceof Project)) {
// no explicit columns selection, for example "from employees"
@@ -746,18 +761,44 @@ record PreAnalysisResult(
        Map<String, IndexResolution> lookupIndices,
EnrichResolution enrichResolution,
        Set<String> fieldNames,
- Set<String> wildcardJoinIndices
+ Set<String> wildcardJoinIndices,
+ InferenceResolution inferenceResolution
) {
PreAnalysisResult(EnrichResolution newEnrichResolution) {
- this(null, new HashMap<>(), newEnrichResolution, Set.of(), Set.of());
+ this(null, new HashMap<>(), newEnrichResolution, Set.of(), Set.of(), InferenceResolution.EMPTY);
}
PreAnalysisResult withEnrichResolution(EnrichResolution newEnrichResolution) {
- return new PreAnalysisResult(indices(), lookupIndices(), newEnrichResolution, fieldNames(), wildcardJoinIndices());
+ return new PreAnalysisResult(
+ indices(),
+ lookupIndices(),
+ newEnrichResolution,
+ fieldNames(),
+ wildcardJoinIndices(),
+ inferenceResolution()
+ );
+ }
+
+ PreAnalysisResult withInferenceResolution(InferenceResolution newInferenceResolution) {
+ return new PreAnalysisResult(
+ indices(),
+ lookupIndices(),
+ enrichResolution(),
+ fieldNames(),
+ wildcardJoinIndices(),
+ newInferenceResolution
+ );
}
PreAnalysisResult withIndexResolution(IndexResolution newIndexResolution) {
- return new PreAnalysisResult(newIndexResolution, lookupIndices(), enrichResolution(), fieldNames(), wildcardJoinIndices());
+ return new PreAnalysisResult(
+ newIndexResolution,
+ lookupIndices(),
+ enrichResolution(),
+ fieldNames(),
+ wildcardJoinIndices(),
+ inferenceResolution()
+ );
}
PreAnalysisResult addLookupIndexResolution(String index, IndexResolution newIndexResolution) {
@@ -766,11 +807,25 @@ PreAnalysisResult addLookupIndexResolution(String index, IndexResolution newInde
}
        PreAnalysisResult withFieldNames(Set<String> newFields) {
- return new PreAnalysisResult(indices(), lookupIndices(), enrichResolution(), newFields, wildcardJoinIndices());
+ return new PreAnalysisResult(
+ indices(),
+ lookupIndices(),
+ enrichResolution(),
+ newFields,
+ wildcardJoinIndices(),
+ inferenceResolution()
+ );
}
        public PreAnalysisResult withWildcardJoinIndices(Set<String> wildcardJoinIndices) {
- return new PreAnalysisResult(indices(), lookupIndices(), enrichResolution(), fieldNames(), wildcardJoinIndices);
+ return new PreAnalysisResult(
+ indices(),
+ lookupIndices(),
+ enrichResolution(),
+ fieldNames(),
+ wildcardJoinIndices,
+ inferenceResolution()
+ );
}
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index be93fcc561c0c..189ec7774e468 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -66,6 +66,7 @@
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
@@ -119,6 +120,7 @@
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
@@ -261,6 +263,10 @@ public final void test() throws Throwable {
"enrich can't load fields in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.ENRICH_LOAD.capabilityName())
);
+ assumeFalse(
+ "can't use rerank in csv tests",
+ testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.RERANK.capabilityName())
+ );
assumeFalse(
"can't use match in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.MATCH_OPERATOR_COLON.capabilityName())
@@ -478,7 +484,10 @@ private static EnrichPolicy loadEnrichPolicyMapping(String policyFileName) {
private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.MultiIndexTestDataset datasets) {
var indexResolution = loadIndexResolution(datasets);
var enrichPolicies = loadEnrichPolicies();
- var analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies), TEST_VERIFIER);
+ var analyzer = new Analyzer(
+ new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies, emptyInferenceResolution()),
+ TEST_VERIFIER
+ );
LogicalPlan plan = analyzer.analyze(parsed);
plan.setAnalyzed();
LOGGER.debug("Analyzed plan:\n{}", plan);
@@ -666,6 +675,7 @@ void executeSubPlan(
() -> exchangeSink.createExchangeSink(() -> {}),
Mockito.mock(EnrichLookupService.class),
Mockito.mock(LookupFromIndexService.class),
+ Mockito.mock(InferenceRunner.class),
physicalOperationProviders,
List.of()
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
index d4e786a9d9bb0..db9047be3f065 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
@@ -8,12 +8,15 @@
package org.elasticsearch.xpack.esql.analysis;
import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
+import org.elasticsearch.xpack.esql.inference.InferenceResolution;
+import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.parser.QueryParams;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -29,6 +32,7 @@
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.RANGE_TYPE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
public final class AnalyzerTestUtils {
@@ -57,7 +61,8 @@ public static Analyzer analyzer(IndexResolution indexResolution, Verifier verifi
new EsqlFunctionRegistry(),
indexResolution,
defaultLookupResolution(),
- defaultEnrichResolution()
+ defaultEnrichResolution(),
+ emptyInferenceResolution()
),
verifier
);
@@ -70,7 +75,8 @@ public static Analyzer analyzer(IndexResolution indexResolution, Map analyze(
+ "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `completion-inference-id`",
+ "mapping-books.json"
+ )
+
+ );
+ assertThat(
+ ve.getMessage(),
+ containsString(
+ "cannot use inference endpoint [completion-inference-id] with task type [completion] within a Rerank command. "
+ + "Only inference endpoints with the task type [rerank] are supported"
+ )
+ );
+ }
+
+ {
+ VerificationException ve = expectThrows(
+ VerificationException.class,
+ () -> analyze(
+ "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `error-inference-id`",
+ "mapping-books.json"
+ )
+
+ );
+ assertThat(ve.getMessage(), containsString("error with inference resolution"));
+ }
+
+ {
+ VerificationException ve = expectThrows(
+ VerificationException.class,
+ () -> analyze(
+ "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `unknown-inference-id`",
+ "mapping-books.json"
+ )
+
+ );
+ assertThat(ve.getMessage(), containsString("unresolved inference [unknown-inference-id]"));
+ }
+ }
+
+ public void testResolveRerankFields() {
+ assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ {
+ // Single field.
+ LogicalPlan plan = analyze("""
+ FROM books METADATA _score
+ | WHERE title:"italian food recipe" OR description:"italian food recipe"
+ | KEEP description, title, year, _score
+ | DROP description
+ | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
+ """, "mapping-books.json");
+
+ Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
+ Rerank rerank = as(limit.child(), Rerank.class);
+ EsqlProject keep = as(rerank.child(), EsqlProject.class);
+ EsqlProject drop = as(keep.child(), EsqlProject.class);
+ Filter filter = as(drop.child(), Filter.class);
+ EsRelation relation = as(filter.child(), EsRelation.class);
+
+ Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
+ assertThat(titleAttribute, notNullValue());
+
+ assertThat(rerank.queryText(), equalTo(string("italian food recipe")));
+ assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
+ assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute))));
+ assertThat(
+ rerank.scoreAttribute(),
+ equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
+ );
+ }
+
+ {
+ // Multiple fields.
+ LogicalPlan plan = analyze("""
+ FROM books METADATA _score
+ | WHERE title:"food"
+ | RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year WITH `reranking-inference-id`
+ """, "mapping-books.json");
+
+ Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
+ Rerank rerank = as(limit.child(), Rerank.class);
+ Filter filter = as(rerank.child(), Filter.class);
+ EsRelation relation = as(filter.child(), EsRelation.class);
+
+ assertThat(rerank.queryText(), equalTo(string("food")));
+ assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
+
+ assertThat(rerank.rerankFields(), hasSize(3));
+ Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
+ assertThat(titleAttribute, notNullValue());
+ assertThat(rerank.rerankFields().get(0), equalTo(alias("title", titleAttribute)));
+
+ Attribute descriptionAttribute = relation.output()
+ .stream()
+ .filter(attribute -> attribute.name().equals("description"))
+ .findFirst()
+ .get();
+ assertThat(descriptionAttribute, notNullValue());
+ Alias descriptionAlias = rerank.rerankFields().get(1);
+ assertThat(descriptionAlias.name(), equalTo("description"));
+ assertThat(
+ as(descriptionAlias.child(), Substring.class).children(),
+ equalTo(List.of(descriptionAttribute, literal(0), literal(100)))
+ );
+
+ Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get();
+ assertThat(yearAttribute, notNullValue());
+ assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute)));
+ assertThat(
+ rerank.scoreAttribute(),
+ equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
+ );
+ }
+
+ {
+ VerificationException ve = expectThrows(
+ VerificationException.class,
+ () -> analyze(
+ "FROM books METADATA _score | RERANK \"italian food recipe\" ON missingField WITH `reranking-inference-id`",
+ "mapping-books.json"
+ )
+
+ );
+ assertThat(ve.getMessage(), containsString("Unknown column [missingField]"));
+ }
+ }
+
+ public void testResolveRerankScoreField() {
+ assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ {
+ // When the metadata field is required in FROM, it is reused.
+ LogicalPlan plan = analyze("""
+ FROM books METADATA _score
+ | WHERE title:"italian food recipe" OR description:"italian food recipe"
+ | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
+ """, "mapping-books.json");
+
+ Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
+ Rerank rerank = as(limit.child(), Rerank.class);
+ Filter filter = as(rerank.child(), Filter.class);
+ EsRelation relation = as(filter.child(), EsRelation.class);
+
+ Attribute metadataScoreAttribute = relation.output()
+ .stream()
+ .filter(attr -> attr.name().equals(MetadataAttribute.SCORE))
+ .findFirst()
+ .get();
+ assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute));
+ assertThat(rerank.output(), hasItem(metadataScoreAttribute));
+ }
+
+ {
+ // When the metadata field is not required in FROM, it is added to the output of RERANK
+ LogicalPlan plan = analyze("""
+ FROM books
+ | WHERE title:"italian food recipe" OR description:"italian food recipe"
+ | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
+ """, "mapping-books.json");
+
+ Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
+ Rerank rerank = as(limit.child(), Rerank.class);
+ Filter filter = as(rerank.child(), Filter.class);
+ EsRelation relation = as(filter.child(), EsRelation.class);
+
+ assertThat(relation.output().stream().noneMatch(attr -> attr.name().equals(MetadataAttribute.SCORE)), is(true));
+ assertThat(rerank.scoreAttribute(), equalTo(MetadataAttribute.create(EMPTY, MetadataAttribute.SCORE)));
+ assertThat(rerank.output(), hasItem(rerank.scoreAttribute()));
+ }
+ }
+
@Override
protected IndexAnalyzers createDefaultIndexAnalyzers() {
return super.createDefaultIndexAnalyzers();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java
index e458fb009d5c3..81022c5d69c35 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java
@@ -35,6 +35,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -45,7 +46,7 @@ public class ParsingTests extends ESTestCase {
private final IndexResolution defaultIndex = loadIndexResolution("mapping-basic.json");
private final Analyzer defaultAnalyzer = new Analyzer(
- new AnalyzerContext(TEST_CFG, new EsqlFunctionRegistry(), defaultIndex, emptyPolicyResolution()),
+ new AnalyzerContext(TEST_CFG, new EsqlFunctionRegistry(), defaultIndex, emptyPolicyResolution(), emptyInferenceResolution()),
TEST_VERIFIER
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java
index cf2de30e44456..68a6f38cdd69a 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java
@@ -34,6 +34,7 @@
import java.util.List;
import java.util.Objects;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution;
import static org.hamcrest.Matchers.containsString;
@@ -90,7 +91,13 @@ public EsqlFunctionRegistry snapshotRegistry() {
private static Analyzer analyzer(EsqlFunctionRegistry registry, License.OperationMode operationMode) {
return new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, registry, analyzerDefaultMapping(), defaultEnrichResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ registry,
+ analyzerDefaultMapping(),
+ defaultEnrichResolution(),
+ emptyInferenceResolution()
+ ),
new Verifier(new Metrics(new EsqlFunctionRegistry()), getLicenseState(operationMode))
);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java
new file mode 100644
index 0000000000000..17ecbd0b32393
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
+
+import java.util.List;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class InferenceRunnerTests extends ESTestCase {
+ public void testResolveInferenceIds() throws Exception {
+ InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
+ List<InferencePlan> inferencePlans = List.of(mockInferencePlan("rerank-plan"));
+ SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
+
+ inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
+ throw new RuntimeException(e);
+ }));
+
+ assertBusy(() -> {
+ InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
+ assertNotNull(inferenceResolution);
+ assertThat(inferenceResolution.resolvedInferences(), contains(new ResolvedInference("rerank-plan", TaskType.RERANK)));
+ assertThat(inferenceResolution.hasError(), equalTo(false));
+ });
+ }
+
+ public void testResolveMultipleInferenceIds() throws Exception {
+ InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
+ List<InferencePlan> inferencePlans = List.of(
+ mockInferencePlan("rerank-plan"),
+ mockInferencePlan("rerank-plan"),
+ mockInferencePlan("completion-plan")
+ );
+ SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
+
+ inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
+ throw new RuntimeException(e);
+ }));
+
+ assertBusy(() -> {
+ InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
+ assertNotNull(inferenceResolution);
+
+ assertThat(
+ inferenceResolution.resolvedInferences(),
+ contains(
+ new ResolvedInference("rerank-plan", TaskType.RERANK),
+ new ResolvedInference("completion-plan", TaskType.COMPLETION)
+ )
+ );
+ assertThat(inferenceResolution.hasError(), equalTo(false));
+ });
+ }
+
+ public void testResolveMissingInferenceIds() throws Exception {
+ InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
+ List<InferencePlan> inferencePlans = List.of(mockInferencePlan("missing-plan"));
+
+ SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
+
+ inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
+ throw new RuntimeException(e);
+ }));
+
+ assertBusy(() -> {
+ InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
+ assertNotNull(inferenceResolution);
+
+ assertThat(inferenceResolution.resolvedInferences(), empty());
+ assertThat(inferenceResolution.hasError(), equalTo(true));
+ assertThat(inferenceResolution.getError("missing-plan"), equalTo("inference endpoint not found"));
+ });
+ }
+
+ @SuppressWarnings({ "unchecked", "raw-types" })
+ private static Client mockClient() {
+ Client client = mock(Client.class);
+ doAnswer(i -> {
+ GetInferenceModelAction.Request request = i.getArgument(1, GetInferenceModelAction.Request.class);
+ ActionListener<ActionResponse> listener = (ActionListener<ActionResponse>) i.getArgument(2, ActionListener.class);
+ ActionResponse response = getInferenceModelResponse(request);
+
+ if (response == null) {
+ listener.onFailure(new ResourceNotFoundException("inference endpoint not found"));
+ } else {
+ listener.onResponse(response);
+ }
+
+ return null;
+ }).when(client).execute(eq(GetInferenceModelAction.INSTANCE), any(), any());
+ return client;
+ }
+
+ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.Request request) {
+ GetInferenceModelAction.Response response = mock(GetInferenceModelAction.Response.class);
+
+ if (request.getInferenceEntityId().equals("rerank-plan")) {
+ when(response.getEndpoints()).thenReturn(List.of(mockModelConfig("rerank-plan", TaskType.RERANK)));
+ return response;
+ }
+
+ if (request.getInferenceEntityId().equals("completion-plan")) {
+ when(response.getEndpoints()).thenReturn(List.of(mockModelConfig("completion-plan", TaskType.COMPLETION)));
+ return response;
+ }
+
+ return null;
+ }
+
+ private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {
+ return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class));
+ }
+
+ private static InferencePlan mockInferencePlan(String inferenceId) {
+ InferencePlan plan = mock(InferencePlan.class);
+ when(plan.inferenceId()).thenReturn(new Literal(Source.EMPTY, inferenceId, DataType.KEYWORD));
+ return plan;
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java
new file mode 100644
index 0000000000000..2833f3fab0d7a
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java
@@ -0,0 +1,297 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.logging.LoggerMessageFormat;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.AsyncOperator;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.compute.operator.SourceOperator;
+import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
+import org.elasticsearch.compute.test.OperatorTestCase;
+import org.elasticsearch.compute.test.RandomBlock;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.threadpool.FixedExecutorBuilder;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.hamcrest.Matcher;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class RerankOperatorTests extends OperatorTestCase {
+ private static final String ESQL_TEST_EXECUTOR = "esql_test_executor";
+ private static final String SIMPLE_INFERENCE_ID = "test_reranker";
+ private static final String SIMPLE_QUERY = "query text";
+ private ThreadPool threadPool;
+ private List<ElementType> inputChannelElementTypes;
+ private XContentRowEncoder.Factory rowEncoderFactory;
+ private int scoreChannel;
+
+ @Before
+ private void initChannels() {
+ int channelCount = randomIntBetween(2, 10);
+ scoreChannel = randomIntBetween(0, channelCount - 1);
+ inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(this::randomElementType).collect(Collectors.toList());
+ rowEncoderFactory = mockRowEncoderFactory();
+ }
+
+ @Before
+ public void setThreadPool() {
+ int numThreads = randomBoolean() ? 1 : between(2, 16);
+ threadPool = new TestThreadPool(
+ "test",
+ new FixedExecutorBuilder(Settings.EMPTY, ESQL_TEST_EXECUTOR, numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT)
+ );
+ }
+
+ @After
+ public void shutdownThreadPool() {
+ terminate(threadPool);
+ }
+
+ @Override
+ protected Operator.OperatorFactory simple() {
+ InferenceRunner inferenceRunner = mockedSimpleInferenceRunner();
+ return new RerankOperator.Factory(inferenceRunner, SIMPLE_INFERENCE_ID, SIMPLE_QUERY, rowEncoderFactory, scoreChannel);
+ }
+
+ private InferenceRunner mockedSimpleInferenceRunner() {
+ InferenceRunner inferenceRunner = mock(InferenceRunner.class);
+ when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext());
+ doAnswer(invocation -> {
+ @SuppressWarnings("unchecked")
+ ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
+ 1,
+ ActionListener.class
+ );
+ InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
+ when(inferenceResponse.getResults()).thenReturn(
+ mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
+ );
+ listener.onResponse(inferenceResponse);
+ return null;
+ }).when(inferenceRunner).doInference(any(), any());
+
+ return inferenceRunner;
+ }
+
+ private RankedDocsResults mockedRankedDocResults(InferenceAction.Request request) {
+ List<RankedDocsResults.RankedDoc> rankedDocs = new ArrayList<>();
+ for (int rank = 0; rank < request.getInput().size(); rank++) {
+ if (rank % 10 != 0) {
+ rankedDocs.add(new RankedDocsResults.RankedDoc(rank, 1f / rank, request.getInput().get(rank)));
+ }
+ }
+ return new RankedDocsResults(rankedDocs);
+ }
+
+ @Override
+ protected Matcher<String> expectedDescriptionOfSimple() {
+ return expectedToStringOfSimple();
+ }
+
+ @Override
+ protected Matcher<String> expectedToStringOfSimple() {
+ return equalTo(
+ "RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]"
+ );
+ }
+
+ @Override
+ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+ return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
+ @Override
+ protected int remaining() {
+ return size - currentPosition;
+ }
+
+ @Override
+ protected Page createPage(int positionOffset, int length) {
+ Block[] blocks = new Block[inputChannelElementTypes.size()];
+ try {
+ currentPosition += length;
+ for (int b = 0; b < inputChannelElementTypes.size(); b++) {
+ blocks[b] = RandomBlock.randomBlock(
+ blockFactory,
+ inputChannelElementTypes.get(b),
+ length,
+ randomBoolean(),
+ 0,
+ 10,
+ 0,
+ 10
+ ).block();
+ }
+ return new Page(blocks);
+ } catch (Exception e) {
+ Releasables.closeExpectNoException(blocks);
+ throw (e);
+ }
+ }
+ };
+ }
+
+ /**
+ * Ensures that the Operator.Status of this operator has the standard fields.
+ */
+ public void testOperatorStatus() throws IOException {
+ DriverContext driverContext = driverContext();
+ try (var operator = simple().get(driverContext)) {
+ AsyncOperator.Status status = asInstanceOf(AsyncOperator.Status.class, operator.status());
+
+ assertThat(status, notNullValue());
+ assertThat(status.receivedPages(), equalTo(0L));
+ assertThat(status.completedPages(), equalTo(0L));
+ assertThat(status.procesNanos(), greaterThanOrEqualTo(0L));
+ }
+ }
+
+ @Override
+ protected void assertSimpleOutput(List<Page> inputPages, List<Page> resultPages) {
+ assertThat(inputPages, hasSize(resultPages.size()));
+
+ for (int pageId = 0; pageId < inputPages.size(); pageId++) {
+ Page inputPage = inputPages.get(pageId);
+ Page resultPage = resultPages.get(pageId);
+
+ // Check all rows are present and the output shape is unchanged.
+ assertThat(inputPage.getPositionCount(), equalTo(resultPage.getPositionCount()));
+ assertThat(inputPage.getBlockCount(), equalTo(resultPage.getBlockCount()));
+
+ BytesRef readBuffer = new BytesRef();
+
+ for (int channel = 0; channel < inputPage.getBlockCount(); channel++) {
+ Block inputBlock = inputPage.getBlock(channel);
+ Block resultBlock = resultPage.getBlock(channel);
+
+ assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount()));
+ assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType()));
+
+ if (channel == scoreChannel) {
+ assertExpectedScore(asInstanceOf(DoubleBlock.class, resultBlock));
+ } else {
+ switch (inputBlock.elementType()) {
+ case BOOLEAN -> assertBlockContentEquals(inputBlock, resultBlock, BooleanBlock::getBoolean, BooleanBlock.class);
+ case INT -> assertBlockContentEquals(inputBlock, resultBlock, IntBlock::getInt, IntBlock.class);
+ case LONG -> assertBlockContentEquals(inputBlock, resultBlock, LongBlock::getLong, LongBlock.class);
+ case FLOAT -> assertBlockContentEquals(inputBlock, resultBlock, FloatBlock::getFloat, FloatBlock.class);
+ case DOUBLE -> assertBlockContentEquals(inputBlock, resultBlock, DoubleBlock::getDouble, DoubleBlock.class);
+ case BYTES_REF -> assertByteRefsBlockContentEquals(inputBlock, resultBlock, readBuffer);
+ default -> throw new AssertionError(
+ LoggerMessageFormat.format("Unexpected block type {}", inputBlock.elementType())
+ );
+ }
+ }
+ }
+ }
+ }
+
+ private int inputChannelCount() {
+ return inputChannelElementTypes.size();
+ }
+
+ private ElementType randomElementType(int channel) {
+ return channel == scoreChannel ? ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG);
+ }
+
+ private XContentRowEncoder.Factory mockRowEncoderFactory() {
+ XContentRowEncoder.Factory factory = mock(XContentRowEncoder.Factory.class);
+ doAnswer(factoryInvocation -> {
+ DriverContext driverContext = factoryInvocation.getArgument(0, DriverContext.class);
+ XContentRowEncoder rowEncoder = mock(XContentRowEncoder.class);
+ doAnswer(encoderInvocation -> {
+ Page inputPage = encoderInvocation.getArgument(0, Page.class);
+ return driverContext.blockFactory()
+ .newConstantBytesRefBlockWith(new BytesRef(randomRealisticUnicodeOfCodepointLength(4)), inputPage.getPositionCount());
+ }).when(rowEncoder).eval(any(Page.class));
+
+ return rowEncoder;
+ }).when(factory).get(any(DriverContext.class));
+
+ return factory;
+ }
+
+ private void assertExpectedScore(DoubleBlock scoreBlockResult) {
+ assertAllPositions(scoreBlockResult, (pos) -> {
+ if (pos % 10 == 0) {
+ assertThat(scoreBlockResult.isNull(pos), equalTo(true));
+ } else {
+ assertThat(scoreBlockResult.getValueCount(pos), equalTo(1));
+ assertThat(scoreBlockResult.getDouble(scoreBlockResult.getFirstValueIndex(pos)), equalTo((double) (1f / pos)));
+ }
+ });
+ }
+
+ <V extends Block, U> void assertBlockContentEquals(
+ Block input,
+ Block result,
+ BiFunction<V, Integer, U> valueReader,
+ Class<V> blockClass
+ ) {
+ V inputBlock = asInstanceOf(blockClass, input);
+ V resultBlock = asInstanceOf(blockClass, result);
+
+ assertAllPositions(inputBlock, (pos) -> {
+ if (inputBlock.isNull(pos)) {
+ assertThat(resultBlock.isNull(pos), equalTo(inputBlock.isNull(pos)));
+ } else {
+ assertThat(resultBlock.getValueCount(pos), equalTo(inputBlock.getValueCount(pos)));
+ assertThat(resultBlock.getFirstValueIndex(pos), equalTo(inputBlock.getFirstValueIndex(pos)));
+ for (int i = 0; i < inputBlock.getValueCount(pos); i++) {
+ assertThat(
+ valueReader.apply(resultBlock, resultBlock.getFirstValueIndex(pos) + i),
+ equalTo(valueReader.apply(inputBlock, inputBlock.getFirstValueIndex(pos) + i))
+ );
+ }
+ }
+ });
+ }
+
+ private void assertAllPositions(Block block, Consumer<Integer> consumer) {
+ for (int pos = 0; pos < block.getPositionCount(); pos++) {
+ consumer.accept(pos);
+ }
+ }
+
+ private void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) {
+ assertBlockContentEquals(input, result, (BytesRefBlock b, Integer pos) -> b.getBytesRef(pos, readBuffer), BytesRefBlock.class);
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java
new file mode 100644
index 0000000000000..b4dfd87224a3a
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/ResolvedInferenceTests.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.AbstractWireTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+public class ResolvedInferenceTests extends AbstractWireTestCase<ResolvedInference> {
+
+ @Override
+ protected ResolvedInference createTestInstance() {
+ return new ResolvedInference(randomIdentifier(), randomTaskType());
+ }
+
+ @Override
+ protected ResolvedInference mutateInstance(ResolvedInference instance) throws IOException {
+ if (randomBoolean()) {
+ return new ResolvedInference(randomValueOtherThan(instance.inferenceId(), ESTestCase::randomIdentifier), instance.taskType());
+ }
+
+ return new ResolvedInference(instance.inferenceId(), randomValueOtherThan(instance.taskType(), this::randomTaskType));
+ }
+
+ @Override
+ protected ResolvedInference copyInstance(ResolvedInference instance, TransportVersion version) throws IOException {
+ return copyInstance(instance, getNamedWriteableRegistry(), (out, v) -> v.writeTo(out), in -> new ResolvedInference(in), version);
+ }
+
+ private TaskType randomTaskType() {
+ return randomFrom(TaskType.values());
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java
index 6903e5dfce35d..ea95152d038ac 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java
@@ -66,6 +66,8 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.asLimit;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
@@ -98,7 +100,13 @@ public static void init() {
logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
analyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResult,
+ emptyPolicyResolution(),
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
}
@@ -449,7 +457,13 @@ public void testSparseDocument() throws Exception {
var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResult,
+ emptyPolicyResolution(),
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
index ba996519b307f..b4e8df78d2963 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
@@ -101,6 +101,7 @@
import static org.elasticsearch.index.query.QueryBuilders.termsQuery;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
@@ -184,7 +185,7 @@ private Analyzer makeAnalyzer(String mappingFileName, EnrichResolution enrichRes
IndexResolution getIndexResult = IndexResolution.valid(test);
return new Analyzer(
- new AnalyzerContext(config, new EsqlFunctionRegistry(), getIndexResult, enrichResolution),
+ new AnalyzerContext(config, new EsqlFunctionRegistry(), getIndexResult, enrichResolution, emptyInferenceResolution()),
new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L))
);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index 4a5c3381351e9..b514d51d5ca02 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -155,6 +155,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.asLimit;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptySource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
@@ -248,7 +249,8 @@ public static void init() {
new EsqlFunctionRegistry(),
getIndexResult,
defaultLookupResolution(),
- enrichResolution
+ enrichResolution,
+ emptyInferenceResolution()
),
TEST_VERIFIER
);
@@ -258,7 +260,13 @@ public static void init() {
EsIndex airports = new EsIndex("airports", mappingAirports, Map.of("airports", IndexMode.STANDARD));
IndexResolution getIndexResultAirports = IndexResolution.valid(airports);
analyzerAirports = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResultAirports,
+ enrichResolution,
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
@@ -267,7 +275,13 @@ public static void init() {
EsIndex types = new EsIndex("types", mappingTypes, Map.of("types", IndexMode.STANDARD));
IndexResolution getIndexResultTypes = IndexResolution.valid(types);
analyzerTypes = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResultTypes,
+ enrichResolution,
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
@@ -276,14 +290,26 @@ public static void init() {
EsIndex extra = new EsIndex("extra", mappingExtra, Map.of("extra", IndexMode.STANDARD));
IndexResolution getIndexResultExtra = IndexResolution.valid(extra);
analyzerExtra = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultExtra, enrichResolution),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResultExtra,
+ enrichResolution,
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
metricMapping = loadMapping("k8s-mappings.json");
var metricsIndex = IndexResolution.valid(new EsIndex("k8s", metricMapping, Map.of("k8s", IndexMode.TIME_SERIES)));
metricsAnalyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), metricsIndex, enrichResolution),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ metricsIndex,
+ enrichResolution,
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
@@ -298,7 +324,13 @@ public static void init() {
)
);
multiIndexAnalyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), multiIndex, enrichResolution),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ multiIndex,
+ enrichResolution,
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
}
@@ -5268,7 +5300,13 @@ public void testEmptyMappingIndex() {
EsIndex empty = new EsIndex("empty_test", emptyMap(), Map.of());
IndexResolution getIndexResultAirports = IndexResolution.valid(empty);
var analyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResultAirports,
+ enrichResolution,
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
index 2af859bfabc31..19e449d04f6ba 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
@@ -164,6 +164,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
@@ -356,7 +357,7 @@ TestDataSource makeTestDataSource(
EsIndex index = new EsIndex(indexName, mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(index);
Analyzer analyzer = new Analyzer(
- new AnalyzerContext(config, functionRegistry, getIndexResult, lookupResolution, enrichResolution),
+ new AnalyzerContext(config, functionRegistry, getIndexResult, lookupResolution, enrichResolution, emptyInferenceResolution()),
TEST_VERIFIER
);
return new TestDataSource(mapping, index, analyzer, stats);
@@ -7673,6 +7674,7 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP
() -> exchangeSinkHandler.createExchangeSink(() -> {}),
null,
null,
+ null,
new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null),
List.of()
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateInlineEvalsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateInlineEvalsTests.java
index 8f2063146cabd..82e8da3ddef78 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateInlineEvalsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateInlineEvalsTests.java
@@ -40,6 +40,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;
@@ -63,7 +64,8 @@ public static void init() {
new EsqlFunctionRegistry(),
getIndexResult,
defaultLookupResolution(),
- new EnrichResolution()
+ new EnrichResolution(),
+ defaultInferenceResolution()
),
TEST_VERIFIER
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java
index 564a88cc8a7e9..25754684bf4f3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java
@@ -30,6 +30,10 @@ public void testDevelopmentMatch() throws Exception {
parse("row a = 1 | match foo", "match");
}
+ public void testDevelopmentRerank() {
+ parse("row a = 1 | rerank \"foo\" ON title WITH reranker", "rerank");
+ }
+
void parse(String query, String errorMessage) {
ParsingException pe = expectThrows(ParsingException.class, () -> parser().createStatement(query));
assertThat(pe.getMessage(), containsString("mismatched input '" + errorMessage + "'"));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
index 515564052e79b..f3f68b484124b 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
@@ -64,6 +64,7 @@
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@@ -3332,6 +3333,87 @@ public void testPreserveParanthesis() {
expectError("explain [row x = 1", "line 1:19: missing ']' at ''");
}
+ public void testRerankSingleField() {
+ assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ var plan = processingCommand("RERANK \"query text\" ON title WITH inferenceID");
+ var rerank = as(plan, Rerank.class);
+
+ assertThat(rerank.queryText(), equalTo(literalString("query text")));
+ assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
+ assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
+ }
+
+ public void testRerankMultipleFields() {
+ assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH inferenceID");
+ var rerank = as(plan, Rerank.class);
+
+ assertThat(rerank.queryText(), equalTo(literalString("query text")));
+ assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
+ assertThat(
+ rerank.rerankFields(),
+ equalTo(
+ List.of(
+ alias("title", attribute("title")),
+ alias("description", attribute("description")),
+ alias("authors_renamed", attribute("authors"))
+ )
+ )
+ );
+ }
+
+ public void testRerankComputedFields() {
+ assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ var plan = processingCommand("RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH inferenceID");
+ var rerank = as(plan, Rerank.class);
+
+ assertThat(rerank.queryText(), equalTo(literalString("query text")));
+ assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
+ assertThat(
+ rerank.rerankFields(),
+ equalTo(
+ List.of(
+ alias("title", attribute("title")),
+ alias("short_description", function("SUBSTRING", List.of(attribute("description"), integer(0), integer(100))))
+ )
+ )
+ );
+ }
+
+ public void testRerankWithPositionalParameters() {
+ assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ var queryParams = new QueryParams(List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker")));
+ var rerank = as(parser.createStatement("row a = 1 | RERANK ? ON title WITH ?", queryParams), Rerank.class);
+
+ assertThat(rerank.queryText(), equalTo(literalString("query text")));
+ assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
+ assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
+ }
+
+ public void testRerankWithNamedParameters() {
+ assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+ var queryParams = new QueryParams(List.of(paramAsConstant("queryText", "query text"), paramAsConstant("inferenceId", "reranker")));
+ var rerank = as(parser.createStatement("row a = 1 | RERANK ?queryText ON title WITH ?inferenceId", queryParams), Rerank.class);
+
+ assertThat(rerank.queryText(), equalTo(literalString("query text")));
+ assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
+ assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
+ }
+
+ public void testInvalidRerank() {
+ assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+ expectError("FROM foo* | RERANK ON title WITH inferenceId", "line 1:20: mismatched input 'ON' expecting {QUOTED_STRING");
+
+ expectError("FROM foo* | RERANK \"query text\" WITH inferenceId", "line 1:33: mismatched input 'WITH' expecting 'on'");
+
+ expectError("FROM foo* | RERANK \"query text\" ON title", "line 1:41: mismatched input '<EOF>' expecting {'and',");
+ }
+
static Alias alias(String name, Expression value) {
return new Alias(EMPTY, name, value);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java
new file mode 100644
index 0000000000000..1bb8bab502f92
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/inference/RerankSerializationTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.logical.inference;
+
+import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.AliasTests;
+import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
+
+public class RerankSerializationTests extends AbstractLogicalPlanSerializationTests<Rerank> {
+ @Override
+ protected Rerank createTestInstance() {
+ Source source = randomSource();
+ LogicalPlan child = randomChild(0);
+ return new Rerank(source, child, string(randomIdentifier()), string(randomIdentifier()), randomFields(), scoreAttribute());
+ }
+
+ @Override
+ protected Rerank mutateInstance(Rerank instance) throws IOException {
+ LogicalPlan child = instance.child();
+ Expression inferenceId = instance.inferenceId();
+ Expression queryText = instance.queryText();
+ List<Alias> fields = instance.rerankFields();
+
+ switch (between(0, 3)) {
+ case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
+ case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> string(RerankSerializationTests.randomIdentifier()));
+ case 2 -> queryText = randomValueOtherThan(queryText, () -> string(RerankSerializationTests.randomIdentifier()));
+ case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
+ }
+ return new Rerank(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute());
+ }
+
+ @Override
+ protected boolean alwaysEmptySource() {
+ return true;
+ }
+
+ private List<Alias> randomFields() {
+ return randomList(0, 10, AliasTests::randomAlias);
+ }
+
+ private Literal string(String value) {
+ return new Literal(EMPTY, value, DataType.KEYWORD);
+ }
+
+ private Attribute scoreAttribute() {
+ return new MetadataAttribute(EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, randomBoolean());
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java
new file mode 100644
index 0000000000000..ecdbb1a1b4fd0
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankExecSerializationTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical.inference;
+
+import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.AliasTests;
+import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
+
+public class RerankExecSerializationTests extends AbstractPhysicalPlanSerializationTests<RerankExec> {
+ @Override
+ protected RerankExec createTestInstance() {
+ Source source = randomSource();
+ PhysicalPlan child = randomChild(0);
+ return new RerankExec(source, child, string(randomIdentifier()), string(randomIdentifier()), randomFields(), scoreAttribute());
+ }
+
+ @Override
+ protected RerankExec mutateInstance(RerankExec instance) throws IOException {
+ PhysicalPlan child = instance.child();
+ Expression inferenceId = instance.inferenceId();
+ Expression queryText = instance.queryText();
+ List<Alias> fields = instance.rerankFields();
+
+ switch (between(0, 3)) {
+ case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
+ case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> string(RerankExecSerializationTests.randomIdentifier()));
+ case 2 -> queryText = randomValueOtherThan(queryText, () -> string(RerankExecSerializationTests.randomIdentifier()));
+ case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
+ }
+ return new RerankExec(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute());
+ }
+
+ @Override
+ protected boolean alwaysEmptySource() {
+ return true;
+ }
+
+ private List<Alias> randomFields() {
+ return randomList(0, 10, AliasTests::randomAlias);
+ }
+
+ static Literal string(String value) {
+ return new Literal(EMPTY, value, DataType.KEYWORD);
+ }
+
+ private Attribute scoreAttribute() {
+ return new MetadataAttribute(EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, randomBoolean());
+ }
+}
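
The two serialization test classes above construct their nodes through the same six-argument constructor. The sketch below is illustrative only and mirrors the argument order used in mutateInstance (source, child, inferenceId, queryText, rerankFields, scoreAttribute); child and titleAttribute are hypothetical placeholders.

    // Illustrative sketch of a concrete Rerank logical-plan node.
    Rerank rerank = new Rerank(
        Source.EMPTY,
        child,                                                      // any upstream LogicalPlan
        new Literal(Source.EMPTY, "reranker", DataType.KEYWORD),    // inference endpoint id
        new Literal(Source.EMPTY, "query text", DataType.KEYWORD),  // query text
        List.of(new Alias(Source.EMPTY, "title", titleAttribute)),  // fields sent to the reranker
        new MetadataAttribute(Source.EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, true)
    );
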
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java
index 4c1b009e847ed..04d2cac31af55 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java
@@ -50,6 +50,7 @@
import static org.elasticsearch.index.query.QueryBuilders.rangeQuery;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
@@ -84,7 +85,13 @@ public static void init() {
mapper = new Mapper();
analyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResult,
+ EsqlTestUtils.emptyPolicyResolution(),
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
}
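
The same mechanical update repeats in QueryTranslatorTests, ClusterRequestTests and DataNodeRequestSerializationTests below: AnalyzerContext now takes an additional InferenceResolution argument, which these tests satisfy with EsqlTestUtils.emptyInferenceResolution(). A condensed sketch of the shared pattern, assuming an IndexResolution named getIndexResult as in the tests:

    // Condensed sketch of the updated test-analyzer wiring used throughout this patch.
    Analyzer analyzer = new Analyzer(
        new AnalyzerContext(
            EsqlTestUtils.TEST_CFG,
            new EsqlFunctionRegistry(),
            getIndexResult,                           // IndexResolution for the test mapping
            EsqlTestUtils.emptyPolicyResolution(),    // no enrich policies
            EsqlTestUtils.emptyInferenceResolution()  // no inference endpoints (new argument)
        ),
        TEST_VERIFIER
    );
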
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java
index 6a7571d6964c9..e07f6bc51e286 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java
@@ -232,6 +232,7 @@ private LocalExecutionPlanner planner() throws IOException {
null,
null,
null,
+ null,
esPhysicalOperationProviders(shardContexts),
shardContexts
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/QueryTranslatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/QueryTranslatorTests.java
index 64f073310d3e6..02b108dcf6adb 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/QueryTranslatorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/QueryTranslatorTests.java
@@ -13,7 +13,6 @@
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
-import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
@@ -28,6 +27,8 @@
import java.util.List;
import java.util.Map;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.hamcrest.Matchers.containsString;
@@ -46,7 +47,13 @@ private static Analyzer makeAnalyzer(String mappingFileName) {
IndexResolution getIndexResult = IndexResolution.valid(test);
return new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, new EnrichResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResult,
+ emptyPolicyResolution(),
+ emptyInferenceResolution()
+ ),
new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L))
);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
index e58824290c49e..9b7615d0cc37e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
@@ -39,6 +39,7 @@
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
@@ -191,7 +192,13 @@ static LogicalPlan parse(String query) {
IndexResolution getIndexResult = IndexResolution.valid(test);
var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResult,
+ emptyPolicyResolution(),
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java
index fac3495697da8..abf9b527f008d 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java
@@ -44,6 +44,7 @@
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
@@ -292,7 +293,13 @@ static LogicalPlan parse(String query) {
IndexResolution getIndexResult = IndexResolution.valid(test);
var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG, FoldContext.small()));
var analyzer = new Analyzer(
- new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
+ new AnalyzerContext(
+ EsqlTestUtils.TEST_CFG,
+ new EsqlFunctionRegistry(),
+ getIndexResult,
+ emptyPolicyResolution(),
+ emptyInferenceResolution()
+ ),
TEST_VERIFIER
);
return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
index 7d4a120668a8b..34e2af8034527 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
@@ -86,7 +86,7 @@ public TestServiceModel parsePersistedConfigWithSecrets(
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
var taskSettingsMap = getTaskSettingsMap(config);
- var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+ var taskSettings = getTasksSettingsFromMap(taskSettingsMap);
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}
@@ -99,11 +99,15 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
+ protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
+ return TestTaskSettings.fromMap(taskSettingsMap);
+ }
+
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);
@Override
@@ -149,15 +153,15 @@ public TestServiceModel(
TaskType taskType,
String service,
ServiceSettings serviceSettings,
- TestTaskSettings taskSettings,
+ TaskSettings taskSettings,
TestSecretSettings secretSettings
) {
super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
}
@Override
- public TestTaskSettings getTaskSettings() {
- return (TestTaskSettings) super.getTaskSettings();
+ public TaskSettings getTaskSettings() {
+ return super.getTaskSettings();
}
@Override
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java
index eef0da909f529..1d04aab022f91 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java
@@ -45,6 +45,11 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
TestRerankingServiceExtension.TestServiceSettings.NAME,
TestRerankingServiceExtension.TestServiceSettings::new
),
+ new NamedWriteableRegistry.Entry(
+ TaskSettings.class,
+ TestRerankingServiceExtension.TestTaskSettings.NAME,
+ TestRerankingServiceExtension.TestTaskSettings::new
+ ),
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
TestStreamingCompletionServiceExtension.TestServiceSettings.NAME,
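
The TestRerankingServiceExtension changes below make the mock reranker's scoring controllable through task settings. A hedged sketch of how a test might build those settings, using the keys parsed by TestTaskSettings.fromMap further down (use_text_length, min_score, result_diff); when use_text_length is enabled the mock scores each input as 1 / text length, giving a deterministic ordering.

    // Hedged sketch, not part of the patch. fromMap consumes (removes) the keys it
    // recognizes, so a mutable map is passed.
    Map<String, Object> settingsMap = new HashMap<>();
    settingsMap.put("use_text_length", true); // score = 1f / input.length()
    settingsMap.put("min_score", -1.0f);      // ignored while use_text_length is true
    settingsMap.put("result_diff", 0.2f);
    TestRerankingServiceExtension.TestTaskSettings taskSettings =
        TestRerankingServiceExtension.TestTaskSettings.fromMap(settingsMap);
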
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
index 989726443ecf4..b496ea783c002 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
@@ -27,6 +27,7 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
+import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
@@ -43,6 +44,8 @@
import java.util.List;
import java.util.Map;
+import static org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService.random;
+
public class TestRerankingServiceExtension implements InferenceServiceExtension {
@Override
@@ -84,11 +87,15 @@ public void parseRequestConfig(
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
var taskSettingsMap = getTaskSettingsMap(config);
- var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+ var taskSettings = TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
}
+ protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
+ return TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
+ }
+
@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
@@ -107,13 +114,15 @@ public void infer(
@Nullable Integer topN,
List<String> input,
boolean stream,
- Map<String, Object> taskSettings,
+ Map<String, Object> taskSettingsMap,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
+ TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap);
+
switch (model.getConfigurations().getTaskType()) {
- case ANY, RERANK -> listener.onResponse(makeResults(input));
+ case ANY, RERANK -> listener.onResponse(makeResults(input, (TestRerankingServiceExtension.TestTaskSettings) taskSettings));
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -151,7 +160,7 @@ public void chunkedInfer(
);
}
- private RankedDocsResults makeResults(List<String> input) {
+ private RankedDocsResults makeResults(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
int totalResults = input.size();
try {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
@@ -161,17 +170,19 @@ private RankedDocsResults makeResults(List<String> input) {
return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList());
} catch (NumberFormatException ex) {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
- float minScore = random.nextFloat(-1f, 1f);
- float resultDiff = 0.2f;
+
+ float minScore = taskSettings.minScore();
+ float resultDiff = taskSettings.resultDiff();
for (int i = 0; i < input.size(); i++) {
- results.add(
- new RankedDocsResults.RankedDoc(
- totalResults - 1 - i,
- minScore + resultDiff * (totalResults - i),
- input.get(totalResults - 1 - i)
- )
- );
+ float relevanceScore = minScore + resultDiff * (totalResults - i);
+ String inputText = input.get(totalResults - 1 - i);
+ if (taskSettings.useTextLength()) {
+ relevanceScore = 1f / inputText.length();
+ }
+ results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
}
+ // Ensure results are sorted by descending relevance score
+ results.sort((a, b) -> Float.compare(b.relevanceScore(), a.relevanceScore()));
return new RankedDocsResults(results);
}
}
@@ -208,6 +219,77 @@ public static InferenceServiceConfiguration get() {
}
}
+ public record TestTaskSettings(boolean useTextLength, float minScore, float resultDiff) implements TaskSettings {
+
+ static final String NAME = "test_reranking_task_settings";
+
+ public static TestTaskSettings fromMap(Map<String, Object> map) {
+ boolean useTextLength = false;
+ float minScore = random.nextFloat(-1f, 1f);
+ float resultDiff = 0.2f;
+
+ if (map.containsKey("use_text_length")) {
+ useTextLength = Boolean.parseBoolean(map.remove("use_text_length").toString());
+ }
+
+ if (map.containsKey("min_score")) {
+ minScore = Float.parseFloat(map.remove("min_score").toString());
+ }
+
+ if (map.containsKey("result_diff")) {
+ resultDiff = Float.parseFloat(map.remove("result_diff").toString());
+ }
+
+ return new TestTaskSettings(useTextLength, minScore, resultDiff);
+ }
+
+ public TestTaskSettings(StreamInput in) throws IOException {
+ this(in.readBoolean(), in.readFloat(), in.readFloat());
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return false;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeBoolean(useTextLength);
+ out.writeFloat(minScore);
+ out.writeFloat(resultDiff);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ builder.field("use_text_length", useTextLength);
+ builder.field("min_score", minScore);
+ builder.field("result_diff", resultDiff);
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
+ }
+
+ @Override
+ public TaskSettings updatedTaskSettings(Map<String, Object> newSettingsMap) {
+ TestTaskSettings newSettingsObject = fromMap(new java.util.HashMap<>(newSettingsMap));
+ return new TestTaskSettings(
+ newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength,
+ newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore,
+ newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff
+ );
+ }
+ }
+
public record TestServiceSettings(String modelId) implements ServiceSettings {
static final String NAME = "test_reranking_service_settings";