@@ -24,6 +24,7 @@ import java.lang.ref.WeakReference
24
24
import java .util .concurrent .ConcurrentHashMap
25
25
26
26
import org .apache .spark .Logging
27
+ import java .util .concurrent .atomic .AtomicInteger
27
28
28
29
private [util] case class TimeStampedWeakValue [T ](timestamp : Long , weakValue : WeakReference [T ]) {
29
30
def this (timestamp : Long , value : T ) = this (timestamp, new WeakReference [T ](value))
@@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea
44
45
private [spark] class TimeStampedWeakValueHashMap [A , B ]()
45
46
extends WrappedJavaHashMap [A , B , A , TimeStampedWeakValue [B ]] with Logging {
46
47
48
+ /** Number of inserts after which keys whose weak ref values are null will be cleaned */
49
+ private val CLEANUP_INTERVAL = 1000
50
+
51
+ /** Counter for counting the number of inserts */
52
+ private val insertCounts = new AtomicInteger (0 )
53
+
47
54
protected [util] val internalJavaMap : util.Map [A , TimeStampedWeakValue [B ]] = {
48
55
new ConcurrentHashMap [A , TimeStampedWeakValue [B ]]()
49
56
}
@@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
52
59
new TimeStampedWeakValueHashMap [K1 , V1 ]()
53
60
}
54
61
62
+ override def += (kv : (A , B )): this .type = {
63
+ // Cleanup null value at certain intervals
64
+ if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0 ) {
65
+ cleanNullValues()
66
+ }
67
+ super .+= (kv)
68
+ }
69
+
55
70
override def get (key : A ): Option [B ] = {
56
71
Option (internalJavaMap.get(key)) match {
57
72
case Some (weakValue) =>
58
73
val value = weakValue.weakValue.get
59
- if (value == null ) cleanupKey(key)
74
+ if (value == null ) {
75
+ internalJavaMap.remove(key)
76
+ }
60
77
Option (value)
61
78
case None =>
62
79
None
@@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
72
89
}
73
90
74
91
override def iterator : Iterator [(A , B )] = {
75
- val jIterator = internalJavaMap.entrySet().iterator()
76
- JavaConversions .asScalaIterator(jIterator).flatMap(kv => {
77
- val key = kv.getKey
78
- val value = kv.getValue.weakValue.get
79
- if (value == null ) {
80
- cleanupKey(key)
81
- Seq .empty
82
- } else {
83
- Seq ((key, value))
84
- }
92
+ val iterator = internalJavaMap.entrySet().iterator()
93
+ JavaConversions .asScalaIterator(iterator).flatMap(kv => {
94
+ val (key, value) = (kv.getKey, kv.getValue.weakValue.get)
95
+ if (value != null ) Seq ((key, value)) else Seq .empty
85
96
})
86
97
}
87
98
@@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
104
115
}
105
116
}
106
117
107
- private def cleanupKey (key : A ) {
108
- // TODO: Consider cleaning up keys to empty weak ref values automatically in future.
118
+ /**
119
+ * Removes keys whose weak referenced values have become null.
120
+ */
121
+ private def cleanNullValues () {
122
+ val iterator = internalJavaMap.entrySet().iterator()
123
+ while (iterator.hasNext) {
124
+ val entry = iterator.next()
125
+ if (entry.getValue.weakValue.get == null ) {
126
+ logDebug(" Removing key " + entry.getKey)
127
+ iterator.remove()
128
+ }
129
+ }
109
130
}
110
131
111
132
private def currentTime = System .currentTimeMillis()
0 commit comments