Commit 9e17731

[SPARK-43190][SQL] ListQuery.childOutput should be consistent with child output
### What changes were proposed in this pull request?

Update `ListQuery` to only store the number of columns of the original plan, instead of directly storing the original plan output attributes.

### Why are the changes needed?

Storing the plan output attributes is troublesome, as we have to maintain them and keep them in sync with the plan. For example, `DeduplicateRelations` may change the plan output, and today we do not update `ListQuery.childOutputs` to keep it in sync. `ListQuery.childOutputs` was added by apache#18968. It is only used to track the original plan output attributes, since subquery de-correlation may add more columns. We can achieve the same thing by storing the number of columns of the plan.

### Does this PR introduce _any_ user-facing change?

No, there is no user-facing bug exposed.

### How was this patch tested?

A new plan test.

Closes apache#40851 from cloud-fan/list_query.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 09a4353 commit 9e17731
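To illustrate the core idea, here is a minimal standalone sketch (with simplified `Plan` and `Attribute` stand-ins, not Spark's actual classes): cached output attributes go stale when a rule such as `DeduplicateRelations` re-assigns expression IDs, while a stored column count plus a derived accessor stays consistent by construction.

```scala
// Hypothetical stand-ins for Spark's Attribute and LogicalPlan, for illustration only.
case class Attribute(name: String, exprId: Long)
case class Plan(output: Seq[Attribute])

// Before: the cached attributes can diverge from the (rewritten) plan's output.
case class ListQueryBefore(plan: Plan, childOutputs: Seq[Attribute] = Seq.empty)

// After: only the column count is stored; childOutputs is derived on demand.
case class ListQueryAfter(plan: Plan, numCols: Int = -1) {
  def childOutputs: Seq[Attribute] = plan.output.take(numCols)
}

object StaleOutputDemo extends App {
  val original  = Plan(Seq(Attribute("a", exprId = 1)))
  // Simulate DeduplicateRelations re-assigning a fresh expression ID.
  val rewritten = Plan(Seq(Attribute("a", exprId = 2)))

  val before = ListQueryBefore(original, original.output).copy(plan = rewritten)
  val after  = ListQueryAfter(original, numCols = 1).copy(plan = rewritten)

  assert(before.childOutputs != before.plan.output) // stale: still carries exprId = 1
  assert(after.childOutputs == after.plan.output)   // derived: always in sync
}
```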

10 files changed, +45 −22 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
@@ -2425,7 +2425,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
           if values.forall(_.resolved) && !l.resolved =>
         val expr = resolveSubQuery(l, outer)((plan, exprs) => {
-          ListQuery(plan, exprs, exprId, plan.output)
+          ListQuery(plan, exprs, exprId, plan.output.length)
         })
         InSubquery(values, expr.asInstanceOf[ListQuery])
       case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 4 additions & 5 deletions
@@ -361,11 +361,11 @@ abstract class TypeCoercionBase {

       // Handle type casting required between value expression and subquery output
       // in IN subquery.
-      case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions, _))
-          if !i.resolved && lhs.length == sub.output.length =>
+      case i @ InSubquery(lhs, l: ListQuery)
+          if !i.resolved && lhs.length == l.plan.output.length =>
         // LHS is the value expressions of IN subquery.
         // RHS is the subquery output.
-        val rhs = sub.output
+        val rhs = l.plan.output

         val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
           findWiderTypeForTwo(l.dataType, r.dataType)
@@ -383,8 +383,7 @@ abstract class TypeCoercionBase {
             case (e, _) => e
           }

-          val newSub = Project(castedRhs, sub)
-          InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
+          InSubquery(newLhs, l.withNewPlan(Project(castedRhs, l.plan)))
         } else {
           i
         }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 2 additions & 2 deletions
@@ -367,12 +367,12 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
   final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)

   override def checkInputDataTypes(): TypeCheckResult = {
-    if (values.length != query.childOutputs.length) {
+    if (values.length != query.numCols) {
       DataTypeMismatch(
         errorSubClass = "IN_SUBQUERY_LENGTH_MISMATCH",
         messageParameters = Map(
           "leftLength" -> values.length.toString,
-          "rightLength" -> query.childOutputs.length.toString,
+          "rightLength" -> query.numCols.toString,
           "leftColumns" -> values.map(toSQLExpr(_)).mkString(", "),
           "rightColumns" -> query.childOutputs.map(toSQLExpr(_)).mkString(", ")
         )

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,16 +354,19 @@ case class ListQuery(
354354
plan: LogicalPlan,
355355
outerAttrs: Seq[Expression] = Seq.empty,
356356
exprId: ExprId = NamedExpression.newExprId,
357-
childOutputs: Seq[Attribute] = Seq.empty,
357+
// The plan of list query may have more columns after de-correlation, and we need to track the
358+
// number of the columns of the original plan, to report the data type properly.
359+
numCols: Int = -1,
358360
joinCond: Seq[Expression] = Seq.empty,
359361
hint: Option[HintInfo] = None)
360362
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
361-
override def dataType: DataType = if (childOutputs.length > 1) {
363+
def childOutputs: Seq[Attribute] = plan.output.take(numCols)
364+
override def dataType: DataType = if (numCols > 1) {
362365
childOutputs.toStructType
363366
} else {
364-
childOutputs.head.dataType
367+
plan.output.head.dataType
365368
}
366-
override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
369+
override lazy val resolved: Boolean = childrenResolved && plan.resolved && numCols != -1
367370
override def nullable: Boolean = false
368371
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
369372
override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint)
@@ -373,7 +376,7 @@ case class ListQuery(
373376
plan.canonicalized,
374377
outerAttrs.map(_.canonicalized),
375378
ExprId(0),
376-
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
379+
numCols,
377380
joinCond.map(_.canonicalized))
378381
}
379382
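The `take(numCols)` accessor is the crux of this file: after de-correlation appends join columns to the subquery plan, only the first `numCols` columns are reported as the subquery's output. A small sketch of that behavior, again using hypothetical stand-in types rather than Spark's actual classes:

```scala
// Simplified stand-ins, for illustration only (not Spark's classes).
case class Attribute(name: String)
case class Plan(output: Seq[Attribute])

case class ListQuery(plan: Plan, numCols: Int = -1) {
  // Only the first numCols columns belong to the user's original subquery.
  def childOutputs: Seq[Attribute] = plan.output.take(numCols)
  def resolved: Boolean = numCols != -1
}

object DecorrelationDemo extends App {
  val original = Plan(Seq(Attribute("a")))
  val lq = ListQuery(original, numCols = original.output.length)
  assert(lq.resolved)

  // De-correlation (as in PullupCorrelatedPredicates) may append extra join columns.
  val decorrelated = lq.copy(plan = Plan(original.output :+ Attribute("outer_ref")))

  assert(decorrelated.plan.output.length == 2)
  assert(decorrelated.childOutputs == Seq(Attribute("a"))) // still only the original column
}
```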

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       return filterApplicationSidePlan
     }
     val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)),
-      ListQuery(aggregate, childOutputs = aggregate.output))
+      ListQuery(aggregate, numCols = aggregate.output.length))
     Filter(filter, filterApplicationSidePlan)
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 2 additions & 2 deletions
@@ -346,10 +346,10 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
-      case ListQuery(sub, children, exprId, childOutputs, conditions, hint) if children.nonEmpty =>
+      case ListQuery(sub, children, exprId, numCols, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         val joinCond = getJoinCondition(newCond, conditions)
-        ListQuery(newPlan, children, exprId, childOutputs, joinCond, hint)
+        ListQuery(newPlan, children, exprId, numCols, joinCond, hint)
       case LateralSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
         LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 24 additions & 0 deletions
@@ -1500,4 +1500,28 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       assert(refs.map(_.output).distinct.length == 3)
     }
   }
+
+  test("SPARK-43190: ListQuery.childOutput should be consistent with child output") {
+    val listQuery1 = ListQuery(testRelation2.select($"a"))
+    val listQuery2 = ListQuery(testRelation2.select($"b"))
+    val plan = testRelation3.where($"f".in(listQuery1) && $"f".in(listQuery2)).analyze
+    val resolvedCondition = plan.expressions.head
+    val finalPlan = testRelation2.join(testRelation3).where(resolvedCondition).analyze
+    val resolvedListQueries = finalPlan.expressions.flatMap(_.collect {
+      case l: ListQuery => l
+    })
+    assert(resolvedListQueries.length == 2)
+
+    def collectLocalRelations(plan: LogicalPlan): Seq[LocalRelation] = plan.collect {
+      case l: LocalRelation => l
+    }
+    val localRelations = resolvedListQueries.flatMap(l => collectLocalRelations(l.plan))
+    assert(localRelations.length == 2)
+    // DeduplicateRelations should deduplicate plans in subquery expressions as well.
+    assert(localRelations.head.output != localRelations.last.output)
+
+    resolvedListQueries.foreach { l =>
+      assert(l.childOutputs == l.plan.output)
+    }
+  }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 1 addition & 4 deletions
@@ -78,10 +78,7 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s
       case e: Exists =>
         e.copy(plan = normalizeExprIds(e.plan), exprId = ExprId(0))
       case l: ListQuery =>
-        l.copy(
-          plan = normalizeExprIds(l.plan),
-          exprId = ExprId(0),
-          childOutputs = l.childOutputs.map(_.withExprId(ExprId(0))))
+        l.copy(plan = normalizeExprIds(l.plan), exprId = ExprId(0))
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
       case OuterReference(a: AttributeReference) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
       val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
       val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
       DynamicPruningExpression(expressions.InSubquery(
-        Seq(value), ListQuery(aggregate, childOutputs = aggregate.output)))
+        Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
     }
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala

Lines changed: 1 addition & 1 deletion
@@ -90,6 +90,6 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla

     val buildQuery = Aggregate(buildKeys, buildKeys, matchingRowsPlan)
     DynamicPruningExpression(
-      InSubquery(pruningKeys, ListQuery(buildQuery, childOutputs = buildQuery.output)))
+      InSubquery(pruningKeys, ListQuery(buildQuery, numCols = buildQuery.output.length)))
   }
 }
