@@ -19,65 +19,88 @@ package org.apache.spark.sql.execution
 
 import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
-  def output = projectList.map(_.toAttribute)
+  override def output = projectList.map(_.toAttribute)
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     @transient val reusableProjection = new MutableProjection(projectList)
     iter.map(reusableProjection)
   }
 }
 
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
-  def output = child.output
+  override def output = child.output
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     iter.filter(condition.apply(_).asInstanceOf[Boolean])
   }
 }
 
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
   extends UnaryNode {
 
-  def output = child.output
+  override def output = child.output
 
   // TODO: How to pick seed?
-  def execute() = child.execute().sample(withReplacement, fraction, seed)
+  override def execute() = child.execute().sample(withReplacement, fraction, seed)
 }
 
 case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
-  def output = children.head.output
-  def execute() = sc.union(children.map(_.execute()))
+  override def output = children.head.output
+  override def execute() = sc.union(children.map(_.execute()))
 
   override def otherCopyArgs = sc :: Nil
 }
 
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = {
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
+  }
 }
 
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-               (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+                      (@transient sc: SparkContext) extends UnaryNode {
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sc.makeRDD(executeCollect(), 1)
 }
 
 
@@ -101,15 +124,15 @@ case class Sort(
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
 
-  def execute() = attachTree(this, "sort") {
+  override def execute() = attachTree(this, "sort") {
     // TODO: Optimize sorting operation?
     child.execute()
       .mapPartitions(
        iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
        preservesPartitioning = true)
   }
 
-  def output = child.output
+  override def output = child.output
 }
 
 object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
 }
 
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
-  def execute() = rdd
+  override def execute() = rdd
 }
 
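For context on the new Limit.execute path above: below is a minimal standalone sketch of the same two-phase limit idea using only plain RDD operations. It is not the patch itself; the names TwoPhaseLimitExample and twoPhaseLimit are made up for illustration, and repartition(1) stands in for the ShuffledRDD over HashPartitioner(1) plus SparkSqlSerializer that the patch actually uses.

import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

// Hypothetical helper, not part of the patch: illustrates the two-phase limit only.
object TwoPhaseLimitExample {
  def twoPhaseLimit[T: ClassTag](input: RDD[T], limit: Int): RDD[T] = {
    // Phase 1: keep at most `limit` rows in each partition.
    val locallyLimited = input.mapPartitions(_.take(limit))
    // Move the survivors into a single partition; the patch does this with a
    // ShuffledRDD on a HashPartitioner(1), repartition(1) stands in here.
    val singlePartition = locallyLimited.repartition(1)
    // Phase 2: apply the limit again on the one remaining partition to get the global result.
    singlePartition.mapPartitions(_.take(limit))
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("two-phase-limit"))
    val rows = sc.parallelize(1 to 1000, numSlices = 8)
    // Returns some 5 rows; like Limit, no particular ordering is guaranteed.
    println(twoPhaseLimit(rows, 5).collect().mkString(", "))
    sc.stop()
  }
}

Shuffling only the per-partition survivors moves at most numPartitions * limit rows over the network, which is why the non-terminal execute path prefers this over pulling everything to the driver the way executeCollect does.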