Skip to content

Commit 92bec2f

Browse files
committed
Rerank command analysis and verification.
1 parent 7943a27 commit 92bec2f

28 files changed

+495
-81
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
6767
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
6868
import org.elasticsearch.xpack.esql.index.EsIndex;
69+
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
70+
import org.elasticsearch.xpack.esql.inference.InferenceService;
6971
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
7072
import org.elasticsearch.xpack.esql.parser.QueryParam;
7173
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -374,7 +376,8 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
374376
null,
375377
mock(ClusterService.class),
376378
mock(IndexNameExpressionResolver.class),
377-
null
379+
null,
380+
mock(InferenceService.class)
378381
);
379382

380383
private EsqlTestUtils() {}
@@ -452,6 +455,10 @@ public static EnrichResolution emptyPolicyResolution() {
452455
return new EnrichResolution();
453456
}
454457

458+
public static InferenceResolution emptyInferenceResolution() {
459+
return InferenceResolution.EMPTY;
460+
}
461+
455462
public static SearchStats statsForExistingField(String... names) {
456463
return fieldMatchingExistOrMissing(true, names);
457464
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

+47
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
6565
import org.elasticsearch.xpack.esql.index.EsIndex;
6666
import org.elasticsearch.xpack.esql.index.IndexResolution;
67+
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
6768
import org.elasticsearch.xpack.esql.parser.ParsingException;
6869
import org.elasticsearch.xpack.esql.plan.IndexPattern;
6970
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
@@ -81,6 +82,8 @@
8182
import org.elasticsearch.xpack.esql.plan.logical.Project;
8283
import org.elasticsearch.xpack.esql.plan.logical.Rename;
8384
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
85+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
86+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
8487
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
8588
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
8689
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
@@ -156,6 +159,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
156159
Limiter.ONCE,
157160
new ResolveTable(),
158161
new ResolveEnrich(),
162+
new ResolveInference(),
159163
new ResolveLookupTables(),
160164
new ResolveFunctions(),
161165
new ResolveForkFunctions()
@@ -393,6 +397,34 @@ private static NamedExpression createEnrichFieldExpression(
393397
}
394398
}
395399

400+
private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan, AnalyzerContext> {
401+
@Override
402+
protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) {
403+
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
404+
405+
String inferenceId = plan.inferenceId().fold(FoldContext.small()).toString();
406+
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
407+
408+
if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
409+
return plan;
410+
} else if (resolvedInference != null) {
411+
String error = "cannot use inference endpoint ["
412+
+ inferenceId
413+
+ "] with task type ["
414+
+ resolvedInference.taskType()
415+
+ "] within a "
416+
+ plan.nodeName()
417+
+ " command. Only inference endpoints with the task type ["
418+
+ plan.taskType()
419+
+ "] are supported.";
420+
return plan.withInferenceResolutionError(inferenceId, error);
421+
} else {
422+
String error = context.inferenceResolution().getError(inferenceId);
423+
return plan.withInferenceResolutionError(inferenceId, error);
424+
}
425+
}
426+
}
427+
396428
private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {
397429

398430
@Override
@@ -498,6 +530,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
498530
return resolveFork(f, context);
499531
}
500532

533+
if (plan instanceof Rerank r) {
534+
return resolveRerank(r, childrenOutput);
535+
}
536+
501537
return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
502538
}
503539

@@ -687,6 +723,17 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
687723
return new Fork(fork.source(), fork.child(), newSubPlans);
688724
}
689725

726+
private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childOutput) {
727+
List<Alias> newFields = new ArrayList<>();
728+
boolean changed = false;
729+
for (Alias field : rerank.rerankFields()) {
730+
Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childOutput));
731+
changed |= result != field;
732+
}
733+
734+
return changed ? new Rerank(rerank.source(), rerank.child(), rerank.inferenceId(), rerank.queryText(), newFields) : rerank;
735+
}
736+
690737
private List<Attribute> resolveUsingColumns(List<Attribute> cols, List<Attribute> output, String side) {
691738
List<Attribute> resolved = new ArrayList<>(cols.size());
692739
for (Attribute col : cols) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/AnalyzerContext.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
1111
import org.elasticsearch.xpack.esql.index.IndexResolution;
12+
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
1213
import org.elasticsearch.xpack.esql.session.Configuration;
1314

1415
import java.util.Map;
@@ -18,16 +19,18 @@ public record AnalyzerContext(
1819
EsqlFunctionRegistry functionRegistry,
1920
IndexResolution indexResolution,
2021
Map<String, IndexResolution> lookupResolution,
21-
EnrichResolution enrichResolution
22+
EnrichResolution enrichResolution,
23+
InferenceResolution inferenceResolution
2224
) {
2325
// Currently for tests only, since most do not test lookups
2426
// TODO: make this even simpler, remove the enrichResolution for tests that do not require it (most tests)
2527
public AnalyzerContext(
2628
Configuration configuration,
2729
EsqlFunctionRegistry functionRegistry,
2830
IndexResolution indexResolution,
29-
EnrichResolution enrichResolution
31+
EnrichResolution enrichResolution,
32+
InferenceResolution inferenceResolution
3033
) {
31-
this(configuration, functionRegistry, indexResolution, Map.of(), enrichResolution);
34+
this(configuration, functionRegistry, indexResolution, Map.of(), enrichResolution, inferenceResolution);
3235
}
3336
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java

+13-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
1212
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1313
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
14+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
1415

1516
import java.util.ArrayList;
1617
import java.util.List;
@@ -23,15 +24,22 @@
2324
public class PreAnalyzer {
2425

2526
public static class PreAnalysis {
26-
public static final PreAnalysis EMPTY = new PreAnalysis(emptyList(), emptyList(), emptyList());
27+
public static final PreAnalysis EMPTY = new PreAnalysis(emptyList(), emptyList(), emptyList(), emptyList());
2728

2829
public final List<TableInfo> indices;
2930
public final List<Enrich> enriches;
31+
public final List<InferencePlan> inferencePlans;
3032
public final List<TableInfo> lookupIndices;
3133

32-
public PreAnalysis(List<TableInfo> indices, List<Enrich> enriches, List<TableInfo> lookupIndices) {
34+
public PreAnalysis(
35+
List<TableInfo> indices,
36+
List<Enrich> enriches,
37+
List<InferencePlan> inferencePlans,
38+
List<TableInfo> lookupIndices
39+
) {
3340
this.indices = indices;
3441
this.enriches = enriches;
42+
this.inferencePlans = inferencePlans;
3543
this.lookupIndices = lookupIndices;
3644
}
3745
}
@@ -47,17 +55,19 @@ public PreAnalysis preAnalyze(LogicalPlan plan) {
4755
protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
4856
List<TableInfo> indices = new ArrayList<>();
4957
List<Enrich> unresolvedEnriches = new ArrayList<>();
58+
List<InferencePlan> unresolvedInferencePlans = new ArrayList<>();
5059
List<TableInfo> lookupIndices = new ArrayList<>();
5160

5261
plan.forEachUp(UnresolvedRelation.class, p -> {
5362
List<TableInfo> list = p.indexMode() == IndexMode.LOOKUP ? lookupIndices : indices;
5463
list.add(new TableInfo(p.indexPattern()));
5564
});
5665
plan.forEachUp(Enrich.class, unresolvedEnriches::add);
66+
plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);
5767

5868
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
5969
plan.forEachUp(LogicalPlan::setPreAnalyzed);
6070

61-
return new PreAnalysis(indices, unresolvedEnriches, lookupIndices);
71+
return new PreAnalysis(indices, unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
6272
}
6373
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
11+
12+
import java.util.Collection;
13+
import java.util.Map;
14+
15+
public class InferenceResolution {
16+
17+
public static final InferenceResolution EMPTY = new InferenceResolution();
18+
19+
private final Map<String, ResolvedInference> resolvedInferences = ConcurrentCollections.newConcurrentMap();
20+
21+
private final Map<String, String> errors = ConcurrentCollections.newConcurrentMap();
22+
23+
public ResolvedInference getResolvedInference(String inferenceId) {
24+
return resolvedInferences.get(inferenceId);
25+
}
26+
27+
public Collection<ResolvedInference> resolvedInferences() {
28+
return resolvedInferences.values();
29+
}
30+
31+
public String getError(String inferenceId) {
32+
final String error = errors.get(inferenceId);
33+
if (error != null) {
34+
return error;
35+
} else {
36+
assert false : "unresolved inference [" + inferenceId + "]";
37+
return "unresolved inference [" + inferenceId + "]";
38+
}
39+
}
40+
41+
public void addResolvedInference(ResolvedInference resolvedInference) {
42+
resolvedInferences.putIfAbsent(resolvedInference.inferenceId(), resolvedInference);
43+
}
44+
45+
public void addError(String inferenceId, String reason) {
46+
errors.putIfAbsent(inferenceId, reason);
47+
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.client.internal.Client;
12+
import org.elasticsearch.client.internal.OriginSettingClient;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
15+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
16+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
17+
18+
import java.util.List;
19+
import java.util.Set;
20+
import java.util.concurrent.CountDownLatch;
21+
import java.util.stream.Collectors;
22+
23+
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
24+
25+
public class InferenceService {
26+
27+
private final Client client;
28+
29+
public InferenceService(Client client) {
30+
this.client = new OriginSettingClient(client, ML_ORIGIN);
31+
}
32+
33+
public void resolveInferences(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
34+
35+
if (plans.isEmpty()) {
36+
listener.onResponse(InferenceResolution.EMPTY);
37+
return;
38+
}
39+
40+
Set<String> inferenceIds = plans.stream()
41+
.map(p -> p.inferenceId().fold(FoldContext.small()).toString())
42+
.collect(Collectors.toSet());
43+
44+
CountDownLatch countDownLatch = new CountDownLatch(inferenceIds.size());
45+
InferenceResolution inferenceResolution = new InferenceResolution();
46+
47+
for (var inferenceId : inferenceIds) {
48+
client.execute(
49+
GetInferenceModelAction.INSTANCE,
50+
new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
51+
ActionListener.wrap(r -> {
52+
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
53+
inferenceResolution.addResolvedInference(resolvedInference);
54+
countDownLatch.countDown();
55+
}, e -> {
56+
inferenceResolution.addError(inferenceId, e.getMessage());
57+
countDownLatch.countDown();
58+
})
59+
);
60+
}
61+
62+
try {
63+
countDownLatch.await();
64+
} catch (InterruptedException e) {
65+
throw new RuntimeException(e);
66+
}
67+
68+
listener.onResponse(inferenceResolution);
69+
}
70+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.inference.TaskType;
14+
15+
import java.io.IOException;
16+
17+
public record ResolvedInference(String inferenceId, TaskType taskType) implements Writeable {
18+
19+
public ResolvedInference(StreamInput in) throws IOException {
20+
this(in.readString(), in.readEnum(TaskType.class));
21+
}
22+
23+
@Override
24+
public void writeTo(StreamOutput out) throws IOException {
25+
out.writeString(inferenceId);
26+
out.writeEnum(taskType);
27+
}
28+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ public Literal visitIntegerValue(EsqlBaseParser.IntegerValueContext ctx) {
195195
}
196196

197197
@Override
198-
public String visitStringOrParameter(EsqlBaseParser.StringOrParameterContext ctx) {
198+
public Literal visitStringOrParameter(EsqlBaseParser.StringOrParameterContext ctx) {
199199
if (ctx.parameter() != null) {
200200
if (expression(ctx.parameter()) instanceof Literal lit) {
201201
if (lit.value() == null) {
@@ -205,7 +205,7 @@ public String visitStringOrParameter(EsqlBaseParser.StringOrParameterContext ctx
205205
ctx.parameter().getText()
206206
);
207207
}
208-
return lit.value().toString();
208+
return lit;
209209
}
210210

211211
throw new ParsingException(
@@ -215,7 +215,7 @@ public String visitStringOrParameter(EsqlBaseParser.StringOrParameterContext ctx
215215
);
216216
}
217217

218-
return unquote(source(ctx.string()));
218+
return visitString(ctx.string());
219219
}
220220

221221
@Override

0 commit comments

Comments
 (0)