Skip to content

Commit 420e687

Browse files
agubichevcloud-fan
authored andcommitted
[SPARK-43780][SQL] Support correlated references in join predicates for scalar and lateral subqueries
### What changes were proposed in this pull request? This PR adds support to subqueries that involve joins with correlated references in join predicates, e.g. ``` select * from t0 join lateral (select * from t1 join t2 on t1a = t2a and t1a = t0a); ``` (full example in https://issues.apache.org/jira/browse/SPARK-43780) Currently we only handle scalar and lateral subqueries. ### Why are the changes needed? This is a valid SQL that is not yet supported by Spark SQL. ### Does this PR introduce _any_ user-facing change? Yes, previously unsupported queries become supported. ### How was this patch tested? Query and unit tests Closes apache#41301 from agubichev/spark-43780-corr-predicate. Authored-by: Andrey Gubichev <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent f7002fb commit 420e687

File tree

13 files changed

+605
-11
lines changed

13 files changed

+605
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
11731173
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
11741174
case _: Filter => true
11751175
case _: Project => usingDecorrelateInnerQueryFramework
1176+
case _: Join => usingDecorrelateInnerQueryFramework
11761177
case _ => false
11771178
}
11781179

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -804,18 +804,88 @@ object DecorrelateInnerQuery extends PredicateHelper {
804804
(d.copy(child = newChild), joinCond, outerReferenceMap)
805805

806806
case j @ Join(left, right, joinType, condition, _) =>
807-
val outerReferences = collectOuterReferences(j.expressions)
808-
// Join condition containing outer references is not supported.
809-
assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
810-
val newOuterReferences = parentOuterReferences ++ outerReferences
811-
val shouldPushToLeft = joinType match {
807+
// Given 'condition', computes the tuple of
808+
// (correlated, uncorrelated, equalityCond, predicates, equivalences).
809+
// 'correlated' and 'uncorrelated' are the conjuncts with (resp. without)
810+
// outer (correlated) references. Furthermore, correlated conjuncts are split
811+
// into 'equalityCond' (those that are equalities) and all rest ('predicates').
812+
// 'equivalences' track equivalent attributes given 'equalityCond'.
813+
// The split is only performed if 'shouldDecorrelatePredicates' is true.
814+
// The input parameter 'isInnerJoin' is set to true for INNER joins and helps
815+
// determine whether some predicates can be lifted up from the join (this is only
816+
// valid for inner joins).
817+
// Example: For a 'condition' A = outer(X) AND B > outer(Y) AND C = D, the output
818+
// would be:
819+
// correlated = (A = outer(X), B > outer(Y))
820+
// uncorrelated = (C = D)
821+
// equalityCond = (A = outer(X))
822+
// predicates = (B > outer(Y))
823+
// equivalences: (A -> outer(X))
824+
def splitCorrelatedPredicate(
825+
condition: Option[Expression],
826+
isInnerJoin: Boolean,
827+
shouldDecorrelatePredicates: Boolean):
828+
(Seq[Expression], Seq[Expression], Seq[Expression],
829+
Seq[Expression], AttributeMap[Attribute]) = {
830+
// Similar to Filters above, we split the join condition (if present) into correlated
831+
// and uncorrelated predicates, and separately handle joins under set and aggregation
832+
// operations.
833+
if (shouldDecorrelatePredicates) {
834+
val conditions =
835+
if (condition.isDefined) splitConjunctivePredicates(condition.get)
836+
else Seq.empty[Expression]
837+
val (correlated, uncorrelated) = conditions.partition(containsOuter)
838+
var equivalences =
839+
if (underSetOp) AttributeMap.empty[Attribute]
840+
else collectEquivalentOuterReferences(correlated)
841+
var (equalityCond, predicates) =
842+
if (underSetOp) (Seq.empty[Expression], correlated)
843+
else correlated.partition(canPullUpOverAgg)
844+
// Fully preserve the join predicate for non-inner joins.
845+
if (!isInnerJoin) {
846+
predicates = correlated
847+
equalityCond = Seq.empty[Expression]
848+
equivalences = AttributeMap.empty[Attribute]
849+
}
850+
(correlated, uncorrelated, equalityCond, predicates, equivalences)
851+
} else {
852+
(Seq.empty[Expression],
853+
if (condition.isEmpty) Seq.empty[Expression] else Seq(condition.get),
854+
Seq.empty[Expression],
855+
Seq.empty[Expression],
856+
AttributeMap.empty[Attribute])
857+
}
858+
}
859+
860+
val shouldDecorrelatePredicates =
861+
SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
862+
if (!shouldDecorrelatePredicates) {
863+
val outerReferences = collectOuterReferences(j.expressions)
864+
// Join condition containing outer references is not supported.
865+
assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
866+
}
867+
val (correlated, uncorrelated, equalityCond, predicates, equivalences) =
868+
splitCorrelatedPredicate(condition, joinType == Inner, shouldDecorrelatePredicates)
869+
val outerReferences = collectOuterReferences(j.expressions) ++
870+
collectOuterReferences(predicates)
871+
val newOuterReferences =
872+
parentOuterReferences ++ outerReferences -- equivalences.keySet
873+
var shouldPushToLeft = joinType match {
812874
case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
813875
case _ => hasOuterReferences(left)
814876
}
815877
val shouldPushToRight = joinType match {
816878
case RightOuter | FullOuter => true
817879
case _ => hasOuterReferences(right)
818880
}
881+
if (shouldDecorrelatePredicates && !shouldPushToLeft && !shouldPushToRight
882+
&& !predicates.isEmpty) {
883+
// Neither left nor right children of the join have correlations, but the join
884+
// predicate does, and the correlations can not be replaced via equivalences.
885+
// Introduce a domain join on the left side of the join
886+
// (chosen arbitrarily) to provide values for the correlated attribute reference.
887+
shouldPushToLeft = true;
888+
}
819889
val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
820890
decorrelate(left, newOuterReferences, aggregated, underSetOp)
821891
} else {
@@ -826,8 +896,13 @@ object DecorrelateInnerQuery extends PredicateHelper {
826896
} else {
827897
(right, Nil, AttributeMap.empty[Attribute])
828898
}
829-
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
830-
val newJoinCond = leftJoinCond ++ rightJoinCond
899+
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap ++
900+
equivalences
901+
val newCorrelated =
902+
if (shouldDecorrelatePredicates) {
903+
replaceOuterReferences(correlated, newOuterReferenceMap)
904+
} else Seq.empty[Expression]
905+
val newJoinCond = leftJoinCond ++ rightJoinCond ++ equalityCond
831906
// If we push the dependent join to both sides, we can augment the join condition
832907
// such that both sides are matched on the domain attributes. For example,
833908
// - Left Map: {outer(c1) = c1}
@@ -836,7 +911,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
836911
val augmentedConditions = leftOuterReferenceMap.flatMap {
837912
case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
838913
}
839-
val newCondition = (condition ++ augmentedConditions).reduceOption(And)
914+
val newCondition = (newCorrelated ++ uncorrelated
915+
++ augmentedConditions).reduceOption(And)
840916
val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition)
841917
(newJoin, newJoinCond, newOuterReferenceMap)
842918

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4370,6 +4370,16 @@ object SQLConf {
43704370
.checkValue(_ >= 0, "The threshold of cached local relations must not be negative")
43714371
.createWithDefault(64 * 1024 * 1024)
43724372

4373+
val DECORRELATE_JOIN_PREDICATE_ENABLED =
4374+
buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
4375+
.internal()
4376+
.doc("Decorrelate scalar and lateral subqueries with correlated references in join " +
4377+
"predicates. This configuration is only effective when " +
4378+
"'${DECORRELATE_INNER_QUERY_ENABLED.key}' is true.")
4379+
.version("4.0.0")
4380+
.booleanConf
4381+
.createWithDefault(true)
4382+
43734383
/**
43744384
* Holds information about keys that have been deprecated.
43754385
*

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ class DecorrelateInnerQuerySuite extends PlanTest {
3535
val a3 = AttributeReference("a3", IntegerType)()
3636
val b3 = AttributeReference("b3", IntegerType)()
3737
val c3 = AttributeReference("c3", IntegerType)()
38+
val a4 = AttributeReference("a4", IntegerType)()
39+
val b4 = AttributeReference("b4", IntegerType)()
3840
val t0 = OneRowRelation()
3941
val testRelation = LocalRelation(a, b, c)
4042
val testRelation2 = LocalRelation(x, y, z)
4143
val testRelation3 = LocalRelation(a3, b3, c3)
44+
val testRelation4 = LocalRelation(a4, b4)
4245

4346
private def hasOuterReferences(plan: LogicalPlan): Boolean = {
4447
plan.exists(_.expressions.exists(SubExprUtils.containsOuter))
@@ -198,12 +201,15 @@ class DecorrelateInnerQuerySuite extends PlanTest {
198201
val innerPlan =
199202
Join(
200203
testRelation.as("t1"),
201-
Filter(OuterReference(y) === 3, testRelation),
204+
Filter(OuterReference(y) === b3, testRelation3),
202205
Inner,
203206
Some(OuterReference(x) === a),
204207
JoinHint.NONE)
205-
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan.select()) }
206-
assert(error.getMessage.contains("Correlated column is not allowed in join"))
208+
val correctAnswer =
209+
Join(
210+
testRelation.as("t1"), testRelation3,
211+
Inner, Some(a === a), JoinHint.NONE)
212+
check(innerPlan, outerPlan, correctAnswer, Seq(b3 === y, x === a))
207213
}
208214

209215
test("correlated values in project") {
@@ -454,4 +460,125 @@ class DecorrelateInnerQuerySuite extends PlanTest {
454460
DomainJoin(Seq(x), testRelation))))
455461
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
456462
}
463+
464+
test("SPARK-43780: aggregation in subquery with correlated equi-join") {
465+
// Join in the subquery is on equi-predicates, so all the correlated references can be
466+
// substituted by equivalent ones from the outer query, and domain join is not needed.
467+
val outerPlan = testRelation
468+
val innerPlan =
469+
Aggregate(
470+
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
471+
Project(Seq(x, y, a3, b3),
472+
Join(testRelation2, testRelation3, Inner,
473+
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
474+
475+
val correctAnswer =
476+
Aggregate(
477+
Seq(y), Seq(Alias(count(Literal(1)), "a")(), y),
478+
Project(Seq(x, y, a3, b3),
479+
Join(testRelation2, testRelation3, Inner, Some(And(y === y, x === a3)), JoinHint.NONE)))
480+
check(innerPlan, outerPlan, correctAnswer, Seq(y === a))
481+
}
482+
483+
test("SPARK-43780: aggregation in subquery with correlated non-equi-join") {
484+
// Join in the subquery is on non-equi-predicate, so we introduce a DomainJoin.
485+
val outerPlan = testRelation
486+
val innerPlan =
487+
Aggregate(
488+
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
489+
Project(Seq(x, y, a3, b3),
490+
Join(testRelation2, testRelation3, Inner,
491+
Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE)))
492+
val correctAnswer =
493+
Aggregate(
494+
Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
495+
Project(Seq(x, y, a3, b3, a),
496+
Join(
497+
DomainJoin(Seq(a), testRelation2),
498+
testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE)))
499+
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
500+
}
501+
502+
test("SPARK-43780: aggregation in subquery with correlated left join") {
503+
// Join in the subquery is on equi-predicates, so all the correlated references can be
504+
// substituted by equivalent ones from the outer query, and domain join is not needed.
505+
val outerPlan = testRelation
506+
val innerPlan =
507+
Aggregate(
508+
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
509+
Project(Seq(x, y, a3, b3),
510+
Join(testRelation2, testRelation3, LeftOuter,
511+
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
512+
513+
val correctAnswer =
514+
Aggregate(
515+
Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
516+
Project(Seq(x, y, a3, b3, a),
517+
Join(DomainJoin(Seq(a), testRelation2), testRelation3, LeftOuter,
518+
Some(And(y === a, x === a3)), JoinHint.NONE)))
519+
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
520+
}
521+
522+
test("SPARK-43780: aggregation in subquery with correlated left join, " +
523+
"correlation over right side") {
524+
// Same as above, but the join predicate connects the outer reference and the column from the
525+
// right (optional) side of the left join. Domain join is still not needed.
526+
val outerPlan = testRelation
527+
val innerPlan =
528+
Aggregate(
529+
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
530+
Project(Seq(x, y, a3, b3),
531+
Join(testRelation2, testRelation3, LeftOuter,
532+
Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))
533+
534+
val correctAnswer =
535+
Aggregate(
536+
Seq(b), Seq(Alias(count(Literal(1)), "a")(), b),
537+
Project(Seq(x, y, a3, b3, b),
538+
Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
539+
Some(And(b === b3, x === a3)), JoinHint.NONE)))
540+
check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
541+
}
542+
543+
test("SPARK-43780: correlated left join preserves the join predicates") {
544+
// Left outer join preserves both predicates after being decorrelated.
545+
val outerPlan = testRelation
546+
val innerPlan =
547+
Filter(
548+
IsNotNull(c3),
549+
Project(Seq(x, y, a3, b3, c3),
550+
Join(testRelation2, testRelation3, LeftOuter,
551+
Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))
552+
553+
val correctAnswer =
554+
Filter(
555+
IsNotNull(c3),
556+
Project(Seq(x, y, a3, b3, c3, b),
557+
Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
558+
Some(And(x === a3, b === b3)), JoinHint.NONE)))
559+
check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
560+
}
561+
562+
test("SPARK-43780: union all in subquery with correlated join") {
563+
val outerPlan = testRelation
564+
val innerPlan =
565+
Union(
566+
Seq(Project(Seq(x, b3),
567+
Join(testRelation2, testRelation3, Inner,
568+
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)),
569+
Project(Seq(a4, b4),
570+
testRelation4)))
571+
val correctAnswer =
572+
Union(
573+
Seq(Project(Seq(x, b3, a),
574+
Project(Seq(x, b3, a),
575+
Join(
576+
DomainJoin(Seq(a), testRelation2),
577+
testRelation3, Inner,
578+
Some(And(x === a3, y === a)), JoinHint.NONE))),
579+
Project(Seq(a4, b4, a),
580+
DomainJoin(Seq(a),
581+
Project(Seq(a4, b4), testRelation4)))))
582+
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
583+
}
457584
}

sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,72 @@ Project [c1#x, c2#x]
795795
+- LocalRelation [col1#x, col2#x]
796796

797797

798+
-- !query
799+
SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
800+
-- !query analysis
801+
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
802+
+- LateralJoin lateral-subquery#x [c1#x], Inner
803+
: +- SubqueryAlias __auto_generated_subquery_name
804+
: +- Project [c1#x, c2#x, c1#x, c2#x]
805+
: +- Join Inner, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
806+
: :- SubqueryAlias spark_catalog.default.t2
807+
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
808+
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
809+
: : +- LocalRelation [col1#x, col2#x]
810+
: +- SubqueryAlias spark_catalog.default.t4
811+
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
812+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
813+
: +- LocalRelation [col1#x, col2#x]
814+
+- SubqueryAlias spark_catalog.default.t1
815+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
816+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
817+
+- LocalRelation [col1#x, col2#x]
818+
819+
820+
-- !query
821+
SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1)
822+
-- !query analysis
823+
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
824+
+- LateralJoin lateral-subquery#x [c1#x], Inner
825+
: +- SubqueryAlias __auto_generated_subquery_name
826+
: +- Project [c1#x, c2#x, c1#x, c2#x]
827+
: +- Join Inner, (NOT (c1#x = c1#x) AND NOT (c1#x = outer(c1#x)))
828+
: :- SubqueryAlias spark_catalog.default.t2
829+
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
830+
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
831+
: : +- LocalRelation [col1#x, col2#x]
832+
: +- SubqueryAlias spark_catalog.default.t4
833+
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
834+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
835+
: +- LocalRelation [col1#x, col2#x]
836+
+- SubqueryAlias spark_catalog.default.t1
837+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
838+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
839+
+- LocalRelation [col1#x, col2#x]
840+
841+
842+
-- !query
843+
SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
844+
-- !query analysis
845+
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
846+
+- LateralJoin lateral-subquery#x [c1#x], LeftOuter
847+
: +- SubqueryAlias __auto_generated_subquery_name
848+
: +- Project [c1#x, c2#x, c1#x, c2#x]
849+
: +- Join LeftOuter, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
850+
: :- SubqueryAlias spark_catalog.default.t4
851+
: : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
852+
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
853+
: : +- LocalRelation [col1#x, col2#x]
854+
: +- SubqueryAlias spark_catalog.default.t2
855+
: +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
856+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
857+
: +- LocalRelation [col1#x, col2#x]
858+
+- SubqueryAlias spark_catalog.default.t1
859+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
860+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
861+
+- LocalRelation [col1#x, col2#x]
862+
863+
798864
-- !query
799865
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
800866
-- !query analysis

0 commit comments

Comments
 (0)