@@ -132,14 +132,30 @@ class ConnectedComponents private[graphframes] (
132
132
*/
133
133
/** Gets the checkpoint interval (number of iterations between checkpoints). */
def getCheckpointInterval: Int = checkpointInterval
134
134
// Storage level applied to every multi-pass intermediate dataset produced by the
// algorithm; MEMORY_AND_DISK avoids recomputation while spilling under memory pressure.
private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK

/**
 * Sets the storage level for intermediate datasets that require multiple passes
 * (default: ``MEMORY_AND_DISK``).
 *
 * @param value the [[StorageLevel]] to use when persisting intermediate datasets
 * @return this instance, to allow chained configuration calls
 */
def setIntermediateStorageLevel(value: StorageLevel): this.type = {
  intermediateStorageLevel = value
  this
}

/**
 * Gets the storage level for intermediate datasets that require multiple passes.
 */
def getIntermediateStorageLevel: StorageLevel = intermediateStorageLevel
/**
 * Runs the connected components algorithm with the configured settings.
 *
 * Delegates to the companion object's implementation, forwarding the algorithm
 * choice, broadcast threshold, checkpoint interval, and the storage level used
 * for intermediate datasets.
 *
 * @return a DataFrame with the component assignment for each vertex
 */
def run(): DataFrame = {
  ConnectedComponents.run(
    graph,
    algorithm = algorithm,
    broadcastThreshold = broadcastThreshold,
    checkpointInterval = checkpointInterval,
    intermediateStorageLevel = intermediateStorageLevel)
}
}
145
161
@@ -256,7 +272,8 @@ object ConnectedComponents extends Logging {
256
272
graph : GraphFrame ,
257
273
algorithm : String ,
258
274
broadcastThreshold : Int ,
259
- checkpointInterval : Int ): DataFrame = {
275
+ checkpointInterval : Int ,
276
+ intermediateStorageLevel : StorageLevel ): DataFrame = {
260
277
require(supportedAlgorithms.contains(algorithm),
261
278
s " Supported algorithms are { ${supportedAlgorithms.mkString(" , " )}}, but got $algorithm. " )
262
279
@@ -290,7 +307,7 @@ object ConnectedComponents extends Logging {
290
307
logger.info(s " $logPrefix Preparing the graph for connected component computation ... " )
291
308
val g = prepare(graph)
292
309
val vv = g.vertices
293
- var ee = g.edges.persist(StorageLevel . MEMORY_AND_DISK ) // src < dst
310
+ var ee = g.edges.persist(intermediateStorageLevel ) // src < dst
294
311
val numEdges = ee.count()
295
312
logger.info(s " $logPrefix Found $numEdges edges after preparation. " )
296
313
@@ -301,17 +318,17 @@ object ConnectedComponents extends Logging {
301
318
// large-star step
302
319
// compute min neighbors (including self-min)
303
320
val minNbrs1 = minNbrs(ee) // src >= min_nbr
304
- .persist(StorageLevel . MEMORY_AND_DISK )
321
+ .persist(intermediateStorageLevel )
305
322
// connect all strictly larger neighbors to the min neighbor (including self)
306
323
ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)
307
324
.select(col(DST ).as(SRC ), col(MIN_NBR ).as(DST )) // src > dst
308
325
.distinct()
309
- .persist(StorageLevel . MEMORY_AND_DISK )
326
+ .persist(intermediateStorageLevel )
310
327
311
328
// small-star step
312
329
// compute min neighbors (excluding self-min)
313
330
val minNbrs2 = ee.groupBy(col(SRC )).agg(min(col(DST )).as(MIN_NBR ), count(" *" ).as(CNT )) // src > min_nbr
314
- .persist(StorageLevel . MEMORY_AND_DISK )
331
+ .persist(intermediateStorageLevel )
315
332
// connect all smaller neighbors to the min neighbor
316
333
ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
317
334
.select(col(MIN_NBR ).as(SRC ), col(DST )) // src <= dst
@@ -337,7 +354,7 @@ object ConnectedComponents extends Logging {
337
354
System .gc() // hint Spark to clean shuffle directories
338
355
}
339
356
340
- ee.persist(StorageLevel . MEMORY_AND_DISK )
357
+ ee.persist(intermediateStorageLevel )
341
358
342
359
// test convergence
343
360
0 commit comments