Skip to content

Commit 0c0bf6b

Browse files
authored
Resolve Keep plan added to FORK branches (#129754)
1 parent 18c1e55 commit 0c0bf6b

File tree

3 files changed

+74
-40
lines changed
  • x-pack/plugin/esql/src

3 files changed

+74
-40
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,43 @@ public void testWithUnsupportedFieldsAndConflicts() {
873873
assertTrue(e.getMessage().contains("Column [embedding] has conflicting data types"));
874874
}
875875

876+
public void testValidationsAfterFork() {
877+
var firstQuery = """
878+
FROM test*
879+
| FORK ( WHERE true )
880+
( WHERE true )
881+
| DROP _fork
882+
| STATS a = count_distinct(embedding)
883+
""";
884+
885+
var e = expectThrows(VerificationException.class, () -> run(firstQuery));
886+
assertTrue(
887+
e.getMessage().contains("[count_distinct(embedding)] must be [any exact type except unsigned_long, _source, or counter types]")
888+
);
889+
890+
var secondQuery = """
891+
FROM test*
892+
| FORK ( WHERE true )
893+
( WHERE true )
894+
| DROP _fork
895+
| EVAL a = substring(1, 2, 3)
896+
""";
897+
898+
e = expectThrows(VerificationException.class, () -> run(secondQuery));
899+
assertTrue(e.getMessage().contains("first argument of [substring(1, 2, 3)] must be [string], found value [1] type [integer]"));
900+
901+
var thirdQuery = """
902+
FROM test*
903+
| FORK ( WHERE true )
904+
( WHERE true )
905+
| DROP _fork
906+
| EVAL a = b + 2
907+
""";
908+
909+
e = expectThrows(VerificationException.class, () -> run(thirdQuery));
910+
assertTrue(e.getMessage().contains("Unknown column [b]"));
911+
}
912+
876913
public void testWithEvalWithConflictingTypes() {
877914
var query = """
878915
FROM test

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -788,10 +788,8 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
788788
}
789789

790790
List<String> subPlanColumns = logicalPlan.output().stream().map(Attribute::name).toList();
791-
// We need to add an explicit Keep even if the outputs align
792-
// This is because at the moment the sub plans are executed and optimized separately and the output might change
793-
// during optimizations. Once we add streaming we might not need to add a Keep when the outputs already align.
794-
if (logicalPlan instanceof Keep == false || subPlanColumns.equals(forkColumns) == false) {
791+
// We need to add an explicit EsqlProject to align the outputs.
792+
if (logicalPlan instanceof Project == false || subPlanColumns.equals(forkColumns) == false) {
795793
changed = true;
796794
List<Attribute> newOutput = new ArrayList<>();
797795
for (String attrName : forkColumns) {
@@ -801,7 +799,7 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
801799
}
802800
}
803801
}
804-
logicalPlan = new Keep(logicalPlan.source(), logicalPlan, newOutput);
802+
logicalPlan = resolveKeep(new Keep(logicalPlan.source(), logicalPlan, newOutput), logicalPlan.output());
805803
}
806804

807805
newSubPlans.add(logicalPlan);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
import org.elasticsearch.xpack.esql.plan.logical.Filter;
7575
import org.elasticsearch.xpack.esql.plan.logical.Fork;
7676
import org.elasticsearch.xpack.esql.plan.logical.Insist;
77-
import org.elasticsearch.xpack.esql.plan.logical.Keep;
7877
import org.elasticsearch.xpack.esql.plan.logical.Limit;
7978
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
8079
import org.elasticsearch.xpack.esql.plan.logical.Lookup;
@@ -3090,27 +3089,27 @@ public void testBasicFork() {
30903089
// fork branch 1
30913090
limit = as(subPlans.get(0), Limit.class);
30923091
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
3093-
Keep keep = as(limit.child(), Keep.class);
3094-
List<String> keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3095-
assertThat(keptColumns, equalTo(expectedOutput));
3096-
Eval eval = as(keep.child(), Eval.class);
3092+
EsqlProject project = as(limit.child(), EsqlProject.class);
3093+
List<String> projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3094+
assertThat(projectColumns, equalTo(expectedOutput));
3095+
Eval eval = as(project.child(), Eval.class);
30973096
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork1"))));
30983097
Filter filter = as(eval.child(), Filter.class);
30993098
assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(1)));
31003099

31013100
filter = as(filter.child(), Filter.class);
31023101
assertThat(as(filter.condition(), Equals.class).right(), equalTo(string("Chris")));
3103-
EsqlProject project = as(filter.child(), EsqlProject.class);
3102+
project = as(filter.child(), EsqlProject.class);
31043103
var esRelation = as(project.child(), EsRelation.class);
31053104
assertThat(esRelation.indexPattern(), equalTo("test"));
31063105

31073106
// fork branch 2
31083107
limit = as(subPlans.get(1), Limit.class);
31093108
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
3110-
keep = as(limit.child(), Keep.class);
3111-
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3112-
assertThat(keptColumns, equalTo(expectedOutput));
3113-
eval = as(keep.child(), Eval.class);
3109+
project = as(limit.child(), EsqlProject.class);
3110+
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3111+
assertThat(projectColumns, equalTo(expectedOutput));
3112+
eval = as(project.child(), Eval.class);
31143113
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork2"))));
31153114
filter = as(eval.child(), Filter.class);
31163115
assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(2)));
@@ -3124,10 +3123,10 @@ public void testBasicFork() {
31243123
// fork branch 3
31253124
limit = as(subPlans.get(2), Limit.class);
31263125
assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
3127-
keep = as(limit.child(), Keep.class);
3128-
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3129-
assertThat(keptColumns, equalTo(expectedOutput));
3130-
eval = as(keep.child(), Eval.class);
3126+
project = as(limit.child(), EsqlProject.class);
3127+
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3128+
assertThat(projectColumns, equalTo(expectedOutput));
3129+
eval = as(project.child(), Eval.class);
31313130
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork3"))));
31323131
limit = as(eval.child(), Limit.class);
31333132
assertThat(as(limit.limit(), Literal.class).value(), equalTo(7));
@@ -3143,10 +3142,10 @@ public void testBasicFork() {
31433142
// fork branch 4
31443143
limit = as(subPlans.get(3), Limit.class);
31453144
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
3146-
keep = as(limit.child(), Keep.class);
3147-
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3148-
assertThat(keptColumns, equalTo(expectedOutput));
3149-
eval = as(keep.child(), Eval.class);
3145+
project = as(limit.child(), EsqlProject.class);
3146+
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3147+
assertThat(projectColumns, equalTo(expectedOutput));
3148+
eval = as(project.child(), Eval.class);
31503149
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork4"))));
31513150
orderBy = as(eval.child(), OrderBy.class);
31523151
filter = as(orderBy.child(), Filter.class);
@@ -3158,10 +3157,10 @@ public void testBasicFork() {
31583157
// fork branch 5
31593158
limit = as(subPlans.get(4), Limit.class);
31603159
assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
3161-
keep = as(limit.child(), Keep.class);
3162-
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3163-
assertThat(keptColumns, equalTo(expectedOutput));
3164-
eval = as(keep.child(), Eval.class);
3160+
project = as(limit.child(), EsqlProject.class);
3161+
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3162+
assertThat(projectColumns, equalTo(expectedOutput));
3163+
eval = as(project.child(), Eval.class);
31653164
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork5"))));
31663165
limit = as(eval.child(), Limit.class);
31673166
assertThat(as(limit.limit(), Literal.class).value(), equalTo(9));
@@ -3193,11 +3192,11 @@ public void testForkBranchesWithDifferentSchemas() {
31933192
// fork branch 1
31943193
limit = as(subPlans.get(0), Limit.class);
31953194
assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
3196-
Keep keep = as(limit.child(), Keep.class);
3197-
List<String> keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3198-
assertThat(keptColumns, equalTo(expectedOutput));
3195+
EsqlProject project = as(limit.child(), EsqlProject.class);
3196+
List<String> projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3197+
assertThat(projectColumns, equalTo(expectedOutput));
31993198

3200-
Eval eval = as(keep.child(), Eval.class);
3199+
Eval eval = as(project.child(), Eval.class);
32013200
assertEquals(eval.fields().size(), 3);
32023201

32033202
Set<String> evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet());
@@ -3215,7 +3214,7 @@ public void testForkBranchesWithDifferentSchemas() {
32153214
Filter filter = as(orderBy.child(), Filter.class);
32163215
assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(3)));
32173216

3218-
EsqlProject project = as(filter.child(), EsqlProject.class);
3217+
project = as(filter.child(), EsqlProject.class);
32193218
filter = as(project.child(), Filter.class);
32203219
assertThat(as(filter.condition(), Equals.class).right(), equalTo(string("Chris")));
32213220
var esRelation = as(filter.child(), EsRelation.class);
@@ -3224,10 +3223,10 @@ public void testForkBranchesWithDifferentSchemas() {
32243223
// fork branch 2
32253224
limit = as(subPlans.get(1), Limit.class);
32263225
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
3227-
keep = as(limit.child(), Keep.class);
3228-
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3229-
assertThat(keptColumns, equalTo(expectedOutput));
3230-
eval = as(keep.child(), Eval.class);
3226+
project = as(limit.child(), EsqlProject.class);
3227+
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3228+
assertThat(projectColumns, equalTo(expectedOutput));
3229+
eval = as(project.child(), Eval.class);
32313230
assertEquals(eval.fields().size(), 2);
32323231
evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet());
32333232
assertThat(evalFieldNames, equalTo(Set.of("x", "y")));
@@ -3254,10 +3253,10 @@ public void testForkBranchesWithDifferentSchemas() {
32543253
// fork branch 3
32553254
limit = as(subPlans.get(2), Limit.class);
32563255
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
3257-
keep = as(limit.child(), Keep.class);
3258-
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3259-
assertThat(keptColumns, equalTo(expectedOutput));
3260-
eval = as(keep.child(), Eval.class);
3256+
project = as(limit.child(), EsqlProject.class);
3257+
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
3258+
assertThat(projectColumns, equalTo(expectedOutput));
3259+
eval = as(project.child(), Eval.class);
32613260
assertEquals(eval.fields().size(), 2);
32623261
evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet());
32633262
assertThat(evalFieldNames, equalTo(Set.of("emp_no", "first_name")));

0 commit comments

Comments
 (0)