|
64 | 64 | import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
|
65 | 65 | import org.elasticsearch.xpack.esql.index.EsIndex;
|
66 | 66 | import org.elasticsearch.xpack.esql.index.IndexResolution;
|
| 67 | +import org.elasticsearch.xpack.esql.inference.ResolvedInference; |
67 | 68 | import org.elasticsearch.xpack.esql.parser.ParsingException;
|
68 | 69 | import org.elasticsearch.xpack.esql.plan.IndexPattern;
|
69 | 70 | import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
|
|
81 | 82 | import org.elasticsearch.xpack.esql.plan.logical.Project;
|
82 | 83 | import org.elasticsearch.xpack.esql.plan.logical.Rename;
|
83 | 84 | import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
|
| 85 | +import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; |
| 86 | +import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; |
84 | 87 | import org.elasticsearch.xpack.esql.plan.logical.join.Join;
|
85 | 88 | import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
|
86 | 89 | import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
|
@@ -156,6 +159,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
|
156 | 159 | Limiter.ONCE,
|
157 | 160 | new ResolveTable(),
|
158 | 161 | new ResolveEnrich(),
|
| 162 | + new ResolveInference(), |
159 | 163 | new ResolveLookupTables(),
|
160 | 164 | new ResolveFunctions(),
|
161 | 165 | new ResolveForkFunctions()
|
@@ -393,6 +397,34 @@ private static NamedExpression createEnrichFieldExpression(
|
393 | 397 | }
|
394 | 398 | }
|
395 | 399 |
|
| 400 | + private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan, AnalyzerContext> { |
| 401 | + @Override |
| 402 | + protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) { |
| 403 | + assert plan.inferenceId().resolved() && plan.inferenceId().foldable(); |
| 404 | + |
| 405 | + String inferenceId = plan.inferenceId().fold(FoldContext.small()).toString(); |
| 406 | + ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); |
| 407 | + |
| 408 | + if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) { |
| 409 | + return plan; |
| 410 | + } else if (resolvedInference != null) { |
| 411 | + String error = "cannot use inference endpoint [" |
| 412 | + + inferenceId |
| 413 | + + "] with task type [" |
| 414 | + + resolvedInference.taskType() |
| 415 | + + "] within a " |
| 416 | + + plan.nodeName() |
| 417 | + + " command. Only inference endpoints with the task type [" |
| 418 | + + plan.taskType() |
| 419 | + + "] are supported."; |
| 420 | + return plan.withInferenceResolutionError(inferenceId, error); |
| 421 | + } else { |
| 422 | + String error = context.inferenceResolution().getError(inferenceId); |
| 423 | + return plan.withInferenceResolutionError(inferenceId, error); |
| 424 | + } |
| 425 | + } |
| 426 | + } |
| 427 | + |
396 | 428 | private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {
|
397 | 429 |
|
398 | 430 | @Override
|
@@ -498,6 +530,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
|
498 | 530 | return resolveFork(f, context);
|
499 | 531 | }
|
500 | 532 |
|
| 533 | + if (plan instanceof Rerank r) { |
| 534 | + return resolveRerank(r, childrenOutput); |
| 535 | + } |
| 536 | + |
501 | 537 | return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
|
502 | 538 | }
|
503 | 539 |
|
@@ -687,6 +723,17 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
|
687 | 723 | return new Fork(fork.source(), fork.child(), newSubPlans);
|
688 | 724 | }
|
689 | 725 |
|
| 726 | + private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childOutput) { |
| 727 | + List<Alias> newFields = new ArrayList<>(); |
| 728 | + boolean changed = false; |
| 729 | + for (Alias field : rerank.rerankFields()) { |
| 730 | + Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childOutput)); |
| 731 | + changed |= result != field; |
| 732 | + } |
| 733 | + |
| 734 | + return changed ? new Rerank(rerank.source(), rerank.child(), rerank.inferenceId(), rerank.queryText(), newFields) : rerank; |
| 735 | + } |
| 736 | + |
690 | 737 | private List<Attribute> resolveUsingColumns(List<Attribute> cols, List<Attribute> output, String side) {
|
691 | 738 | List<Attribute> resolved = new ArrayList<>(cols.size());
|
692 | 739 | for (Attribute col : cols) {
|
|
0 commit comments