Skip to content

Commit f0aabb1

Browse files
committed
Correct semantics for TimeStampedWeakValueHashMap + add tests
This largely accounts for the cases when WeakReference becomes no longer strongly reachable, in which case the map should return None for all get() operations, and should skip the entry for all listing operations.
1 parent 5016375 commit f0aabb1

File tree

4 files changed

+350
-37
lines changed

4 files changed

+350
-37
lines changed

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
7878
*/
7979
protected def assertValid() {
8080
if (!_isValid) {
81-
throw new SparkException("Attempted to use %s when is no longer valid!".format(toString))
81+
throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
8282
}
8383
}
8484

core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.util.Set
2121
import java.util.Map.Entry
2222
import java.util.concurrent.ConcurrentHashMap
2323

24-
import scala.collection.{immutable, JavaConversions, mutable}
24+
import scala.collection.{JavaConversions, mutable}
2525

2626
import org.apache.spark.Logging
2727

@@ -50,11 +50,11 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
5050
}
5151

5252
def iterator: Iterator[(A, B)] = {
53-
val jIterator = getEntrySet.iterator()
53+
val jIterator = getEntrySet.iterator
5454
JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value))
5555
}
5656

57-
def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet()
57+
def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet
5858

5959
override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
6060
val newMap = new TimeStampedHashMap[A, B1]
@@ -86,8 +86,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
8686
}
8787

8888
override def apply(key: A): B = {
89-
val value = internalMap.get(key)
90-
Option(value).map(_.value).getOrElse { throw new NoSuchElementException() }
89+
get(key).getOrElse { throw new NoSuchElementException() }
9190
}
9291

9392
override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = {
@@ -101,9 +100,9 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
101100
override def size: Int = internalMap.size
102101

103102
override def foreach[U](f: ((A, B)) => U) {
104-
val iterator = getEntrySet.iterator()
105-
while(iterator.hasNext) {
106-
val entry = iterator.next()
103+
val it = getEntrySet.iterator
104+
while(it.hasNext) {
105+
val entry = it.next()
107106
val kv = (entry.getKey, entry.getValue.value)
108107
f(kv)
109108
}
@@ -115,27 +114,39 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
115114
Option(prev).map(_.value)
116115
}
117116

118-
def toMap: immutable.Map[A, B] = iterator.toMap
117+
def putAll(map: Map[A, B]) {
118+
map.foreach { case (k, v) => update(k, v) }
119+
}
120+
121+
def toMap: Map[A, B] = iterator.toMap
119122

120123
def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
121-
val iterator = getEntrySet.iterator()
122-
while (iterator.hasNext) {
123-
val entry = iterator.next()
124+
val it = getEntrySet.iterator
125+
while (it.hasNext) {
126+
val entry = it.next()
124127
if (entry.getValue.timestamp < threshTime) {
125128
f(entry.getKey, entry.getValue.value)
126129
logDebug("Removing key " + entry.getKey)
127-
iterator.remove()
130+
it.remove()
128131
}
129132
}
130133
}
131134

132-
/**
133-
* Removes old key-value pairs that have timestamp earlier than `threshTime`.
134-
*/
135+
/** Removes old key-value pairs that have timestamp earlier than `threshTime`. */
135136
def clearOldValues(threshTime: Long) {
136137
clearOldValues(threshTime, (_, _) => ())
137138
}
138139

139140
private def currentTime: Long = System.currentTimeMillis
140141

142+
// For testing
143+
144+
def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = {
145+
Option(internalMap.get(key))
146+
}
147+
148+
def getTimestamp(key: A): Option[Long] = {
149+
getTimeStampedValue(key).map(_.timestamp)
150+
}
151+
141152
}

core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,61 @@
1818
package org.apache.spark.util
1919

2020
import java.lang.ref.WeakReference
21+
import java.util.concurrent.atomic.AtomicInteger
2122

22-
import scala.collection.{immutable, mutable}
23+
import scala.collection.mutable
24+
25+
import org.apache.spark.Logging
2326

2427
/**
2528
* A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
2629
*
27-
* If the value is garbage collected and the weak reference is null, get() operation returns
28-
* a non-existent value. However, the corresponding key is actually not removed in the current
29-
* implementation. Key-value pairs whose timestamps are older than a particular threshold time
30-
* can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map
31-
* interface to allow it to be a drop-in replacement for Scala HashMaps.
30+
* If the value is garbage collected and the weak reference is null, get() will return a
31+
* non-existent value. These entries are removed from the map periodically (every N inserts), as
32+
* their values are no longer strongly reachable. Further, key-value pairs whose timestamps are
33+
* older than a particular threshold can be removed using the clearOldValues method.
3234
*
33-
* Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe.
35+
* TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it
36+
* to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap,
37+
* so all operations on this HashMap are thread-safe.
3438
*
3539
* @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed.
3640
*/
3741
private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
38-
extends mutable.Map[A, B]() {
42+
extends mutable.Map[A, B]() with Logging {
3943

4044
import TimeStampedWeakValueHashMap._
4145

4246
private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)
47+
private val insertCount = new AtomicInteger(0)
48+
49+
/** Return a map consisting only of entries whose values are still strongly reachable. */
50+
private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null }
4351

4452
def get(key: A): Option[B] = internalMap.get(key)
4553

46-
def iterator: Iterator[(A, B)] = internalMap.iterator
54+
def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator
4755

4856
override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
4957
val newMap = new TimeStampedWeakValueHashMap[A, B1]
58+
val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]]
59+
newMap.internalMap.putAll(oldMap.toMap)
5060
newMap.internalMap += kv
5161
newMap
5262
}
5363

5464
override def - (key: A): mutable.Map[A, B] = {
5565
val newMap = new TimeStampedWeakValueHashMap[A, B]
66+
newMap.internalMap.putAll(nonNullReferenceMap.toMap)
5667
newMap.internalMap -= key
5768
newMap
5869
}
5970

6071
override def += (kv: (A, B)): this.type = {
6172
internalMap += kv
73+
if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) {
74+
clearNullValues()
75+
}
6276
this
6377
}
6478

@@ -71,31 +85,53 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo
7185

7286
override def apply(key: A): B = internalMap.apply(key)
7387

74-
override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p)
88+
override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p)
7589

7690
override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()
7791

7892
override def size: Int = internalMap.size
7993

80-
override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f)
94+
override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f)
8195

8296
def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
8397

84-
def toMap: immutable.Map[A, B] = iterator.toMap
98+
def toMap: Map[A, B] = iterator.toMap
8599

86-
/**
87-
* Remove old key-value pairs that have timestamp earlier than `threshTime`.
88-
*/
100+
/** Remove old key-value pairs with timestamps earlier than `threshTime`. */
89101
def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)
90102

103+
/** Remove entries with values that are no longer strongly reachable. */
104+
def clearNullValues() {
105+
val it = internalMap.getEntrySet.iterator
106+
while (it.hasNext) {
107+
val entry = it.next()
108+
if (entry.getValue.value.get == null) {
109+
logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.")
110+
it.remove()
111+
}
112+
}
113+
}
114+
115+
// For testing
116+
117+
def getTimestamp(key: A): Option[Long] = {
118+
internalMap.getTimeStampedValue(key).map(_.timestamp)
119+
}
120+
121+
def getReference(key: A): Option[WeakReference[B]] = {
122+
internalMap.getTimeStampedValue(key).map(_.value)
123+
}
91124
}
92125

93126
/**
94127
* Helper methods for converting to and from WeakReferences.
95128
*/
96-
private[spark] object TimeStampedWeakValueHashMap {
129+
private object TimeStampedWeakValueHashMap {
97130

98-
/* Implicit conversion methods to WeakReferences */
131+
// Number of inserts after which entries with null references are removed
132+
val CLEAR_NULL_VALUES_INTERVAL = 100
133+
134+
/* Implicit conversion methods to WeakReferences. */
99135

100136
implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)
101137

@@ -107,12 +143,15 @@ private[spark] object TimeStampedWeakValueHashMap {
107143
(kv: (K, WeakReference[V])) => p(kv)
108144
}
109145

110-
/* Implicit conversion methods from WeakReferences */
146+
/* Implicit conversion methods from WeakReferences. */
111147

112148
implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get
113149

114150
implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
115-
v.map(fromWeakReference)
151+
v match {
152+
case Some(ref) => Option(fromWeakReference(ref))
153+
case None => None
154+
}
116155
}
117156

118157
implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
@@ -128,5 +167,4 @@ private[spark] object TimeStampedWeakValueHashMap {
128167
map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
129168
mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
130169
}
131-
132170
}

0 commit comments

Comments
 (0)