diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index 6402745d36005..2bd53cbfc9d30 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -884,6 +884,10 @@ public static T singleValue(Collection collection) { return collection.iterator().next(); } + public static Attribute getAttributeByName(Collection attributes, String name) { + return attributes.stream().filter(attr -> attr.name().equals(name)).findAny().orElse(null); + } + public static Map jsonEntityToMap(HttpEntity entity) throws IOException { return entityToMap(entity, XContentType.JSON); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 980350ce43d4e..428723ce00179 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -89,6 +89,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Rename; import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; import org.elasticsearch.xpack.esql.plan.logical.join.Join; @@ -488,6 +489,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) { return resolveAggregate(aggregate, childrenOutput); } + if (plan instanceof Completion c) { + return resolveCompletion(c, childrenOutput); + } + if (plan instanceof Drop d) { return resolveDrop(d, childrenOutput); } @@ -598,6 +603,21 @@ private Aggregate resolveAggregate(Aggregate aggregate, List children return aggregate; } + private LogicalPlan resolveCompletion(Completion p, List childrenOutput) { + Attribute targetField = p.targetField(); + Expression prompt = p.prompt(); + + if (targetField instanceof UnresolvedAttribute ua) { + targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT); + } + + if (prompt.resolved() == false) { + prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)); + } + + return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField); + } + private LogicalPlan resolveMvExpand(MvExpand p, List childrenOutput) { if (p.target() instanceof UnresolvedAttribute ua) { Attribute resolved = maybeResolveAttribute(ua, childrenOutput); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java index 0ebd1c7c670ac..ea8918838c70d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java @@ -11,12 +11,15 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.GeneratingPlan; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -26,9 +29,15 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.esql.common.Failure.fail; +import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT; import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes; -public class Completion extends InferencePlan implements GeneratingPlan, SortAgnostic { +public class Completion extends InferencePlan + implements + GeneratingPlan, + SortAgnostic, + PostAnalysisVerificationAware { public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion"; @@ -130,6 +139,13 @@ public boolean expressionsResolved() { return super.expressionsResolved() && prompt.resolved(); } + @Override + public void postAnalysisVerification(Failures failures) { + if (prompt.resolved() && DataType.isString(prompt.dataType()) == false) { + failures.add(fail(prompt, "prompt must be of type [{}] but is [{}]", TEXT.typeName(), prompt.dataType().typeName())); + } + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index b0213a419a80c..5474be7efe0bb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; @@ -71,6 +72,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; @@ -92,9 +94,11 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getAttributeByName; import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant; import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsIdentifier; import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsPattern; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; @@ -3460,7 +3464,7 @@ public void testResolveRerankInferenceId() { { LogicalPlan plan = analyze( - " FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`", + "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`", "mapping-books.json" ); Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class); @@ -3530,16 +3534,13 @@ public void testResolveRerankFields() { Filter filter = as(drop.child(), Filter.class); EsRelation relation = as(filter.child(), EsRelation.class); - Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get(); - assertThat(titleAttribute, notNullValue()); + Attribute titleAttribute = getAttributeByName(relation.output(), "title"); + assertThat(getAttributeByName(relation.output(), "title"), notNullValue()); assertThat(rerank.queryText(), equalTo(string("italian food recipe"))); assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id"))); assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute)))); - assertThat( - rerank.scoreAttribute(), - equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get()) - ); + assertThat(rerank.scoreAttribute(), equalTo(getAttributeByName(relation.output(), MetadataAttribute.SCORE))); } { @@ -3559,15 +3560,11 @@ public void testResolveRerankFields() { assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id"))); assertThat(rerank.rerankFields(), hasSize(3)); - Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get(); + Attribute titleAttribute = getAttributeByName(relation.output(), "title"); assertThat(titleAttribute, notNullValue()); assertThat(rerank.rerankFields().get(0), equalTo(alias("title", titleAttribute))); - Attribute descriptionAttribute = relation.output() - .stream() - .filter(attribute -> attribute.name().equals("description")) - .findFirst() - .get(); + Attribute descriptionAttribute = getAttributeByName(relation.output(), "description"); assertThat(descriptionAttribute, notNullValue()); Alias descriptionAlias = rerank.rerankFields().get(1); assertThat(descriptionAlias.name(), equalTo("description")); @@ -3576,13 +3573,11 @@ public void testResolveRerankFields() { equalTo(List.of(descriptionAttribute, literal(0), literal(100))) ); - Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get(); + Attribute yearAttribute = getAttributeByName(relation.output(), "year"); assertThat(yearAttribute, notNullValue()); assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute))); - assertThat( - rerank.scoreAttribute(), - equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get()) - ); + + assertThat(rerank.scoreAttribute(), equalTo(getAttributeByName(relation.output(), MetadataAttribute.SCORE))); } { @@ -3614,11 +3609,7 @@ public void testResolveRerankScoreField() { Filter filter = as(rerank.child(), Filter.class); EsRelation relation = as(filter.child(), EsRelation.class); - Attribute metadataScoreAttribute = relation.output() - .stream() - .filter(attr -> attr.name().equals(MetadataAttribute.SCORE)) - .findFirst() - .get(); + Attribute metadataScoreAttribute = getAttributeByName(relation.output(), MetadataAttribute.SCORE); assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute)); assertThat(rerank.output(), hasItem(metadataScoreAttribute)); } @@ -3642,6 +3633,116 @@ public void testResolveRerankScoreField() { } } + public void testResolveCompletionInferenceId() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` + """, "mapping-books.json"); + Completion completion = as(as(plan, Limit.class).child(), Completion.class); + assertThat(completion.inferenceId(), equalTo(string("completion-inference-id"))); + } + + public void testResolveCompletionInferenceIdInvalidTaskType() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + assertError( + """ + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `reranking-inference-id` + """, + "mapping-books.json", + new QueryParams(), + "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command." + + " Only inference endpoints with the task type [completion] are supported" + ); + } + + public void testResolveCompletionInferenceMissingInferenceId() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + assertError(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `unknown-inference-id` + """, "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]"); + } + + public void testResolveCompletionInferenceIdResolutionError() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + assertError(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `error-inference-id` + """, "mapping-books.json", new QueryParams(), "error with inference resolution"); + } + + public void testResolveCompletionTargetField() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` AS translation + """, "mapping-books.json"); + + Completion completion = as(as(plan, Limit.class).child(), Completion.class); + assertThat(completion.targetField(), equalTo(referenceAttribute("translation", DataType.TEXT))); + } + + public void testResolveCompletionDefaultTargetField() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` + """, "mapping-books.json"); + + Completion completion = as(as(plan, Limit.class).child(), Completion.class); + assertThat(completion.targetField(), equalTo(referenceAttribute("completion", DataType.TEXT))); + } + + public void testResolveCompletionPrompt() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` + """, "mapping-books.json"); + + Completion completion = as(as(plan, Limit.class).child(), Completion.class); + EsRelation esRelation = as(completion.child(), EsRelation.class); + + assertThat( + as(completion.prompt(), Concat.class).children(), + equalTo(List.of(string("Translate the following text in French\n"), getAttributeByName(esRelation.output(), "description"))) + ); + } + + public void testResolveCompletionPromptInvalidType() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + assertError(""" + FROM books METADATA _score + | COMPLETION LENGTH(description) WITH `completion-inference-id` + """, "mapping-books.json", new QueryParams(), "prompt must be of type [text] but is [integer]"); + } + + public void testResolveCompletionOutputField() { + assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled()); + + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` AS description + """, "mapping-books.json"); + + Completion completion = as(as(plan, Limit.class).child(), Completion.class); + assertThat(completion.targetField(), equalTo(referenceAttribute("description", DataType.TEXT))); + + EsRelation esRelation = as(completion.child(), EsRelation.class); + assertThat(getAttributeByName(completion.output(), "description"), equalTo(completion.targetField())); + assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField()))); + } + @Override protected IndexAnalyzers createDefaultIndexAnalyzers() { return super.createDefaultIndexAnalyzers();