17
17
18
18
package org .apache .spark .sql .catalyst .expressions
19
19
20
- import org .apache .spark .sql .catalyst .rules ._
21
-
22
20
/**
23
21
* Rewrites an expression using rules that are guaranteed to preserve the result while attempting
24
22
* to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
@@ -30,52 +28,49 @@ import org.apache.spark.sql.catalyst.rules._
30
28
* - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
31
29
* - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
32
30
* by `hashCode`.
33
- * - [[EqualTo ]] and [[EqualNullSafe ]] are reordered by `hashCode`.
31
+ * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
34
32
* - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
35
33
*/
36
- object Canonicalize extends RuleExecutor [Expression ] {
37
- override protected def batches : Seq [Batch ] =
38
- Batch (
39
- " Expression Canonicalization" , FixedPoint (100 ),
40
- IgnoreNamesTypes ,
41
- Reorder ) :: Nil
34
+ object Canonicalize extends {
35
+ def execute (e : Expression ): Expression = {
36
+ expressionReorder(ignoreNamesTypes(e))
37
+ }
42
38
43
39
/** Remove names and nullability from types. */
44
- protected object IgnoreNamesTypes extends Rule [Expression ] {
45
- override def apply (e : Expression ): Expression = e transformUp {
46
- case a : AttributeReference =>
47
- AttributeReference (" none" , a.dataType.asNullable)(exprId = a.exprId)
48
- }
40
+ private def ignoreNamesTypes (e : Expression ): Expression = e match {
41
+ case a : AttributeReference =>
42
+ AttributeReference (" none" , a.dataType.asNullable)(exprId = a.exprId)
43
+ case _ => e
49
44
}
50
45
51
46
/** Collects adjacent commutative operations. */
52
- protected def gatherCommutative (
47
+ private def gatherCommutative (
53
48
e : Expression ,
54
49
f : PartialFunction [Expression , Seq [Expression ]]): Seq [Expression ] = e match {
55
50
case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
56
51
case other => other :: Nil
57
52
}
58
53
59
54
/** Orders a set of commutative operations by their hash code. */
60
- protected def orderCommutative (
55
+ private def orderCommutative (
61
56
e : Expression ,
62
57
f : PartialFunction [Expression , Seq [Expression ]]): Seq [Expression ] =
63
58
gatherCommutative(e, f).sortBy(_.hashCode())
64
59
65
60
/** Rearrange expressions that are commutative or associative. */
66
- protected object Reorder extends Rule [Expression ] {
67
- override def apply (e : Expression ): Expression = e transformUp {
68
- case a : Add => orderCommutative(a, { case Add (l, r) => Seq (l, r) }).reduce(Add )
69
- case m : Multiply => orderCommutative(m, { case Multiply (l, r) => Seq (l, r) }).reduce(Multiply )
61
+ private def expressionReorder (e : Expression ): Expression = e match {
62
+ case a : Add => orderCommutative(a, { case Add (l, r) => Seq (l, r) }).reduce(Add )
63
+ case m : Multiply => orderCommutative(m, { case Multiply (l, r) => Seq (l, r) }).reduce(Multiply )
64
+
65
+ case EqualTo (l, r) if l.hashCode() > r.hashCode() => EqualTo (r, l)
66
+ case EqualNullSafe (l, r) if l.hashCode() > r.hashCode() => EqualNullSafe (r, l)
70
67
71
- case EqualTo (l, r) if l.hashCode() > r.hashCode() => EqualTo (r, l)
72
- case EqualNullSafe (l, r) if l.hashCode() > r.hashCode() => EqualNullSafe (r, l)
68
+ case GreaterThan (l, r) if l.hashCode() > r.hashCode() => LessThan (r, l)
69
+ case LessThan (l, r) if l.hashCode() > r.hashCode() => GreaterThan (r, l)
73
70
74
- case GreaterThan (l, r) if l.hashCode() > r.hashCode() => LessThan (r, l)
75
- case LessThan (l, r) if l.hashCode() > r.hashCode() => GreaterThan (r, l)
71
+ case GreaterThanOrEqual (l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual (r, l)
72
+ case LessThanOrEqual (l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual (r, l)
76
73
77
- case GreaterThanOrEqual (l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual (r, l)
78
- case LessThanOrEqual (l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual (r, l)
79
- }
74
+ case _ => e
80
75
}
81
76
}
0 commit comments