Commit a7260d3

Added try-catch in context cleaner and null value cleaning in TimeStampedWeakValueHashMap.
1 parent: e61daa0

3 files changed: +64 −34 lines


core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 30 additions & 20 deletions
@@ -50,6 +50,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   /** Start the cleaner */
   def start() {
     cleaningThread.setDaemon(true)
+    cleaningThread.setName("ContextCleaner")
     cleaningThread.start()
   }
 
@@ -60,7 +61,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /**
-   * Clean (unpersist) RDD data. Do not perform any time or resource intensive
+   * Clean RDD data. Do not perform any time or resource intensive
    * computation in this function as this is called from a finalize() function.
    */
   def cleanRDD(rddId: Int) {
@@ -92,39 +93,48 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Keep cleaning RDDs and shuffle data */
   private def keepCleaning() {
-    try {
-      while (!isStopped) {
+    while (!isStopped) {
+      try {
         val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
-        taskOpt.foreach(task => {
+        taskOpt.foreach { task =>
           logDebug("Got cleaning task " + taskOpt.get)
           task match {
-            case CleanRDD(rddId) => doCleanRDD(sc, rddId)
+            case CleanRDD(rddId) => doCleanRDD(rddId)
            case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
          }
-        })
+        }
+      } catch {
+        case ie: InterruptedException =>
+          if (!isStopped) logWarning("Cleaning thread interrupted")
+        case t: Throwable => logError("Error in cleaning thread", t)
       }
-    } catch {
-      case ie: InterruptedException =>
-        if (!isStopped) logWarning("Cleaning thread interrupted")
     }
   }
 
   /** Perform RDD cleaning */
-  private def doCleanRDD(sc: SparkContext, rddId: Int) {
-    logDebug("Cleaning rdd " + rddId)
-    blockManagerMaster.removeRdd(rddId, false)
-    sc.persistentRdds.remove(rddId)
-    listeners.foreach(_.rddCleaned(rddId))
-    logInfo("Cleaned rdd " + rddId)
+  private def doCleanRDD(rddId: Int) {
+    try {
+      logDebug("Cleaning RDD " + rddId)
+      blockManagerMaster.removeRdd(rddId, false)
+      sc.persistentRdds.remove(rddId)
+      listeners.foreach(_.rddCleaned(rddId))
+      logInfo("Cleaned RDD " + rddId)
+    } catch {
+      case t: Throwable => logError("Error cleaning RDD " + rddId, t)
+    }
   }
 
   /** Perform shuffle cleaning */
   private def doCleanShuffle(shuffleId: Int) {
-    logDebug("Cleaning shuffle " + shuffleId)
-    mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-    blockManagerMaster.removeShuffle(shuffleId)
-    listeners.foreach(_.shuffleCleaned(shuffleId))
-    logInfo("Cleaned shuffle " + shuffleId)
+    try {
+      logDebug("Cleaning shuffle " + shuffleId)
+      mapOutputTrackerMaster.unregisterShuffle(shuffleId)
+      blockManagerMaster.removeShuffle(shuffleId)
+      listeners.foreach(_.shuffleCleaned(shuffleId))
+      logInfo("Cleaned shuffle " + shuffleId)
+    } catch {
+      case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t)
+    }
   }
 
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
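
The pattern behind this diff is error isolation in the cleaning thread: the try-catch moves inside the while loop, and each doClean* method gets its own try-catch, so a failure while cleaning one RDD or shuffle is logged and the thread carries on instead of dying. The sketch below is a minimal standalone illustration of that pattern under simplified assumptions; the ResilientCleaner class, CleanupTask type, and println logging are made-up stand-ins, not Spark's actual classes.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

// Hypothetical task type standing in for Spark's CleanRDD / CleanShuffle.
sealed trait CleanupTask
case class CleanResource(id: Int) extends CleanupTask

class ResilientCleaner {
  private val queue = new LinkedBlockingQueue[CleanupTask]()
  @volatile private var stopped = false

  private val cleaningThread = new Thread("cleaning-thread") {
    override def run(): Unit = keepCleaning()
  }

  def start(): Unit = {
    cleaningThread.setDaemon(true)  // never keep the JVM alive just for cleanup
    cleaningThread.start()
  }

  def stop(): Unit = {
    stopped = true
    cleaningThread.interrupt()
  }

  def submit(task: CleanupTask): Unit = queue.put(task)

  // The try-catch lives inside the while loop: an exception thrown while
  // handling one task is logged and the next iteration still runs, whereas
  // a try-catch wrapped around the whole loop would end the thread.
  private def keepCleaning(): Unit = {
    while (!stopped) {
      try {
        val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
        taskOpt.foreach { case CleanResource(id) => doClean(id) }
      } catch {
        case _: InterruptedException =>
          if (!stopped) println("cleaning thread interrupted")
        case t: Throwable =>
          println("error in cleaning thread: " + t)
      }
    }
  }

  private def doClean(id: Int): Unit = {
    // Simulated failure: odd ids throw, showing that the loop survives it.
    if (id % 2 == 1) throw new RuntimeException("failed to clean " + id)
    println("cleaned resource " + id)
  }
}

With this sketch, submitting CleanResource(1) through CleanResource(4) logs two failures but still cleans resources 2 and 4, which is the behavior the commit is after.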

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 0 additions & 1 deletion
@@ -20,7 +20,6 @@ package org.apache.spark
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.Some
 import scala.collection.mutable.{HashSet, Map}
 import scala.concurrent.Await
 
core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala

Lines changed: 34 additions & 13 deletions
@@ -24,6 +24,7 @@ import java.lang.ref.WeakReference
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.Logging
+import java.util.concurrent.atomic.AtomicInteger
 
 private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) {
   def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value))
@@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea
 private[spark] class TimeStampedWeakValueHashMap[A, B]()
   extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging {
 
+  /** Number of inserts after which keys whose weak ref values are null will be cleaned */
+  private val CLEANUP_INTERVAL = 1000
+
+  /** Counter for counting the number of inserts */
+  private val insertCounts = new AtomicInteger(0)
+
   protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = {
     new ConcurrentHashMap[A, TimeStampedWeakValue[B]]()
   }
@@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
     new TimeStampedWeakValueHashMap[K1, V1]()
   }
 
+  override def +=(kv: (A, B)): this.type = {
+    // Cleanup null value at certain intervals
+    if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) {
+      cleanNullValues()
+    }
+    super.+=(kv)
+  }
+
   override def get(key: A): Option[B] = {
     Option(internalJavaMap.get(key)) match {
       case Some(weakValue) =>
         val value = weakValue.weakValue.get
-        if (value == null) cleanupKey(key)
+        if (value == null) {
+          internalJavaMap.remove(key)
+        }
         Option(value)
       case None =>
         None
@@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
   }
 
   override def iterator: Iterator[(A, B)] = {
-    val jIterator = internalJavaMap.entrySet().iterator()
-    JavaConversions.asScalaIterator(jIterator).flatMap(kv => {
-      val key = kv.getKey
-      val value = kv.getValue.weakValue.get
-      if (value == null) {
-        cleanupKey(key)
-        Seq.empty
-      } else {
-        Seq((key, value))
-      }
+    val iterator = internalJavaMap.entrySet().iterator()
+    JavaConversions.asScalaIterator(iterator).flatMap(kv => {
+      val (key, value) = (kv.getKey, kv.getValue.weakValue.get)
+      if (value != null) Seq((key, value)) else Seq.empty
     })
  }
 
@@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
     }
   }
 
-  private def cleanupKey(key: A) {
-    // TODO: Consider cleaning up keys to empty weak ref values automatically in future.
+  /**
+   * Removes keys whose weak referenced values have become null.
+   */
+  private def cleanNullValues() {
+    val iterator = internalJavaMap.entrySet().iterator()
+    while (iterator.hasNext) {
+      val entry = iterator.next()
+      if (entry.getValue.weakValue.get == null) {
+        logDebug("Removing key " + entry.getKey)
+        iterator.remove()
+      }
+    }
   }
 
   private def currentTime = System.currentTimeMillis()
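
This diff amortizes stale-entry cleanup: every CLEANUP_INTERVAL-th insert sweeps the map and drops keys whose weakly referenced values have been garbage-collected, so the map does not accumulate dead entries that get() and iterator never happen to touch. Below is a rough standalone sketch of the same idea, assuming a plain ConcurrentHashMap of WeakReference values; the WeakValueMap name and its API are hypothetical simplifications, not Spark's WrappedJavaHashMap.

import java.lang.ref.WeakReference
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

// A simplified weak-value map: values can be garbage-collected at any time,
// and entries whose values have been collected are purged every
// `cleanupInterval` inserts rather than on every operation.
class WeakValueMap[A, B <: AnyRef](cleanupInterval: Int = 1000) {
  private val map = new ConcurrentHashMap[A, WeakReference[B]]()
  private val insertCount = new AtomicInteger(0)

  def put(key: A, value: B): Unit = {
    // Amortize the O(n) sweep across many O(1) inserts.
    if (insertCount.incrementAndGet() % cleanupInterval == 0) {
      cleanNullValues()
    }
    map.put(key, new WeakReference[B](value))
  }

  def get(key: A): Option[B] = {
    Option(map.get(key)).flatMap { ref =>
      val value = ref.get
      if (value == null) map.remove(key)  // drop the stale entry eagerly
      Option(value)
    }
  }

  def size: Int = map.size()

  /** Remove every entry whose weakly referenced value has been collected. */
  private def cleanNullValues(): Unit = {
    val iterator = map.entrySet().iterator()
    while (iterator.hasNext) {
      if (iterator.next().getValue.get == null) {
        iterator.remove()
      }
    }
  }
}

In this sketch a value that is only reachable through the map becomes eligible for collection, and its key is physically removed either by the next get() on that key or by the next interval sweep, mirroring the two cleanup paths the commit adds.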
