
Commit ed730c9

rxin authored and pwendell committed
StopAfter / TopK related changes
1. Renamed StopAfter to Limit to be more consistent with naming in other relational databases.
2. Renamed TopK to TakeOrdered to be more consistent with the Spark RDD API.
3. Avoid breaking lineage in Limit.
4. Added a bunch of overrides to execution/basicOperators.scala.

@marmbrus @liancheng

Author: Reynold Xin <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes #233 from rxin/limit and squashes the following commits:

13eb12a [Reynold Xin] Merge pull request #1 from marmbrus/limit
92b9727 [Michael Armbrust] More hacks to make Maps serialize with Kryo.
4fc8b4e [Reynold Xin] Merge branch 'master' of github.com:apache/spark into limit
87b7d37 [Reynold Xin] Use the proper serializer in limit.
9b79246 [Reynold Xin] Updated doc for Limit.
47d3327 [Reynold Xin] Copy tuples in Limit before shuffle.
231af3a [Reynold Xin] Limit/TakeOrdered: 1. Renamed StopAfter to Limit to be more consistent with naming in other relational databases. 2. Renamed TopK to TakeOrdered to be more consistent with Spark RDD API. 3. Avoid breaking lineage in Limit. 4. Added a bunch of override's to execution/basicOperators.scala.
1 parent 1faa579 commit ed730c9
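For orientation, here is a minimal sketch of how the renamed operators surface at the query level, assuming the Spark-1.0-era SQLContext API (the table, columns, and registration method names below are illustrative and may not match this exact revision): an ORDER BY plus LIMIT query is planned through the TakeOrdered strategy (formerly TopK), while a bare LIMIT goes to the Limit physical operator (formerly StopAfter).

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical schema, purely for illustration.
case class Person(name: String, age: Int)

object LimitDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("limit-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._  // brings sql(...) and the RDD-to-table implicits into scope (1.0-era API)

    sc.parallelize(Seq(Person("a", 1), Person("b", 2), Person("c", 3))).registerAsTable("people")

    // ORDER BY + LIMIT: planned by the TakeOrdered strategy.
    sql("SELECT name FROM people ORDER BY age LIMIT 2").collect().foreach(println)

    // Bare LIMIT: planned to the Limit physical operator.
    sql("SELECT name FROM people LIMIT 2").collect().foreach(println)

    sc.stop()
  }
}
```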

File tree: 8 files changed, +64 −35 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -181,7 +181,7 @@ class SqlParser extends StandardTokenParsers {
     val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
     val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
     val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
-    val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder)
+    val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
     withLimit
   }
```
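The parser builds the logical plan from the inside out, with each optional clause wrapping the plan assembled so far. A tiny standalone sketch of that Option.map(...).getOrElse(...) pattern, using toy plan classes rather than the real Catalyst nodes:

```scala
// Toy stand-ins for logical plan nodes, only to show the wrapping pattern.
sealed trait Plan
case class Scan(table: String) extends Plan
case class OrderBy(keys: Seq[String], child: Plan) extends Plan
case class Limit(n: Int, child: Plan) extends Plan

object ClauseWrapping {
  // Each optional clause either wraps the current plan or leaves it unchanged,
  // mirroring the withOrder/withLimit chain in SqlParser above.
  def assemble(table: String, order: Option[Seq[String]], limit: Option[Int]): Plan = {
    val base = Scan(table)
    val withOrder = order.map(o => OrderBy(o, base)).getOrElse(base)
    val withLimit = limit.map(n => Limit(n, withOrder)).getOrElse(withOrder)
    withLimit
  }

  def main(args: Array[String]): Unit = {
    println(assemble("t", Some(Seq("a")), Some(10))) // Limit(10, OrderBy(List(a), Scan(t)))
    println(assemble("t", None, None))               // Scan(t)
  }
}
```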

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -130,7 +130,7 @@ case class Aggregate(
   def references = child.references
 }
 
-case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
   def output = child.output
   def references = limit.references
 }
```

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -145,7 +145,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val sparkContext = self.sparkContext
 
     val strategies: Seq[Strategy] =
-      TopK ::
+      TakeOrdered ::
       PartialAggregation ::
       HashJoin ::
       ParquetOperations ::
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -32,7 +32,13 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
     kryo.setRegistrationRequired(false)
     kryo.register(classOf[MutablePair[_, _]])
     kryo.register(classOf[Array[Any]])
+    // This is kinda hacky...
     kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
+    kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
+    kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
```
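The per-class registrations exist because Scala's small immutable maps are not a single class: maps of one to four entries use the specialized Map$Map1 through Map$Map4 implementations, and Kryo registrations are keyed by the concrete runtime class, so each specialization appears to need its own entry (hence the "kinda hacky" note). A plain-Scala check, no Spark required, that prints the classes involved:

```scala
// Prints the concrete class backing an immutable Map as entries are added.
// On the 2.10-era Scala that Spark used at the time, this shows Map$Map1 .. Map$Map4
// and then a HashMap-based implementation once a fifth entry is added.
object SmallMapClasses {
  def main(args: Array[String]): Unit = {
    var m: Map[Int, Int] = Map.empty
    for (n <- 1 to 5) {
      m = m + (n -> n)
      println(s"$n entries -> ${m.getClass.getName}")
    }
  }
}
```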

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     case other => other
   }
 
-  object TopK extends Strategy {
+  object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
+        execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
       case _ => Nil
     }
   }
@@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         sparkContext.parallelize(data.map(r =>
           new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
-      case logical.StopAfter(IntegerLiteral(limit), child) =>
-        execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil
+      case logical.Limit(IntegerLiteral(limit), child) =>
+        execution.Limit(limit, planLater(child))(sparkContext) :: Nil
       case Unions(unionChildren) =>
         execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
       case logical.Generate(generator, join, outer, _, child) =>
```
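The key planning change is visible in the first hunk: a Limit sitting directly on top of a Sort is fused into a single TakeOrdered physical node, while a bare Limit falls through to the Limit operator in the second hunk. A self-contained toy version of that pattern match, using simplified stand-ins rather than the real Catalyst and SparkPlan types:

```scala
// Toy plan trees, only to illustrate the match structure used in the strategies above.
sealed trait Logical
case class LScan(table: String) extends Logical
case class LSort(order: Seq[String], child: Logical) extends Logical
case class LLimit(n: Int, child: Logical) extends Logical

sealed trait Physical
case class PScan(table: String) extends Physical
case class PSort(order: Seq[String], child: Physical) extends Physical
case class PLimit(n: Int, child: Physical) extends Physical
case class PTakeOrdered(n: Int, order: Seq[String], child: Physical) extends Physical

object ToyPlanner {
  def plan(p: Logical): Physical = p match {
    // Limit over Sort: fuse into one ordered-take operator.
    case LLimit(n, LSort(order, child)) => PTakeOrdered(n, order, plan(child))
    // Bare limit: keep a standalone Limit operator.
    case LLimit(n, child)               => PLimit(n, plan(child))
    case LSort(order, child)            => PSort(order, plan(child))
    case LScan(t)                       => PScan(t)
  }

  def main(args: Array[String]): Unit = {
    println(plan(LLimit(10, LSort(Seq("age"), LScan("people")))))  // PTakeOrdered(10, List(age), PScan(people))
    println(plan(LLimit(10, LScan("people"))))                     // PLimit(10, PScan(people))
  }
}
```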

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 47 additions & 24 deletions
```diff
@@ -19,65 +19,88 @@ package org.apache.spark.sql.execution
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
-  def output = projectList.map(_.toAttribute)
+  override def output = projectList.map(_.toAttribute)
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     @transient val reusableProjection = new MutableProjection(projectList)
     iter.map(reusableProjection)
   }
 }
 
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
-  def output = child.output
+  override def output = child.output
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     iter.filter(condition.apply(_).asInstanceOf[Boolean])
   }
 }
 
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
   extends UnaryNode {
 
-  def output = child.output
+  override def output = child.output
 
   // TODO: How to pick seed?
-  def execute() = child.execute().sample(withReplacement, fraction, seed)
+  override def execute() = child.execute().sample(withReplacement, fraction, seed)
 }
 
 case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
-  def output = children.head.output
-  def execute() = sc.union(children.map(_.execute()))
+  override def output = children.head.output
+  override def execute() = sc.union(children.map(_.execute()))
 
   override def otherCopyArgs = sc :: Nil
 }
 
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = {
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
+  }
 }
 
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-               (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+                      (@transient sc: SparkContext) extends UnaryNode {
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sc.makeRDD(executeCollect(), 1)
 }
 
 
@@ -101,15 +124,15 @@ case class Sort(
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
 
-  def execute() = attachTree(this, "sort") {
+  override def execute() = attachTree(this, "sort") {
     // TODO: Optimize sorting operation?
     child.execute()
       .mapPartitions(
         iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
         preservesPartitioning = true)
   }
 
-  def output = child.output
+  override def output = child.output
 }
 
 object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
 }
 
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
-  def execute() = rdd
+  override def execute() = rdd
 }
```
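The interesting part is Limit.execute(): instead of collecting to the driver (which broke lineage in the old StopAfter), it takes at most `limit` rows per partition, shuffles everything into a single partition, and takes `limit` again. A standalone sketch of the same idea against the plain RDD API, with illustrative names rather than the SparkPlan/Row machinery above:

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object LimitSketch {
  // Partition-local limit, then exchange into one partition, then limit again.
  // On Spark releases before 1.3, pair-RDD operations also need
  // `import org.apache.spark.SparkContext._` for the implicit conversions.
  def globalLimit(sc: SparkContext, data: Seq[Int], limit: Int) = {
    val rdd = sc.parallelize(data, 4)
    rdd
      .mapPartitions(_.take(limit))          // at most `limit` rows per partition
      .map(row => (false, row))              // constant key, as with the MutablePair above
      .partitionBy(new HashPartitioner(1))   // shuffle everything into a single partition
      .values
      .mapPartitions(_.take(limit))          // global limit; still an RDD, so lineage is preserved
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("limit-sketch"))
    println(globalLimit(sc, 1 to 100, 5).collect().toSeq)
    sc.stop()
  }
}
```

For the sorted case, the RDD API's takeOrdered(n) offers the analogous semantics, which is the naming the new TakeOrdered operator aligns with.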

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -188,7 +188,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     val hiveContext = self
 
     override val strategies: Seq[Strategy] = Seq(
-      TopK,
+      TakeOrdered,
       ParquetOperations,
       HiveTableScans,
       DataSinks,
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -529,7 +529,7 @@ object HiveQl {
 
         val withLimit =
           limitClause.map(l => nodeToExpr(l.getChildren.head))
-            .map(StopAfter(_, withSort))
+            .map(Limit(_, withSort))
            .getOrElse(withSort)
 
         // TOK_INSERT_INTO means to add files to the table.
@@ -602,7 +602,7 @@ object HiveQl {
       case Token("TOK_TABLESPLITSAMPLE",
             Token("TOK_ROWCOUNT", Nil) ::
             Token(count, Nil) :: Nil) =>
-        StopAfter(Literal(count.toInt), relation)
+        Limit(Literal(count.toInt), relation)
       case Token("TOK_TABLESPLITSAMPLE",
             Token("TOK_PERCENT", Nil) ::
             Token(fraction, Nil) :: Nil) =>
```
