Skip to content

Commit 0b9da6c

Browse files
authored
make intermediate storage level configurable in CC (graphframes#215)
The default storage level MEMORY_AND_DISK might not be the best option for certain setups, so we make it configurable by users.
1 parent 5faffc5 commit 0b9da6c

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

src/main/scala/org/graphframes/lib/ConnectedComponents.scala

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,30 @@ class ConnectedComponents private[graphframes] (
132132
*/
133133
def getCheckpointInterval: Int = checkpointInterval
134134

135+
private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
136+
137+
/**
138+
* Sets storage level for intermediate datasets that require multiple passes (default: ``MEMORY_AND_DISK``).
139+
*/
140+
def setIntermediateStorageLevel(value: StorageLevel): this.type = {
141+
intermediateStorageLevel = value
142+
this
143+
}
144+
145+
/**
146+
* Gets storage level for intermediate datasets that require multiple passes.
147+
*/
148+
def getIntermediateStorageLevel: StorageLevel = intermediateStorageLevel
149+
135150
/**
136151
* Runs the algorithm.
137152
*/
138153
def run(): DataFrame = {
139154
ConnectedComponents.run(graph,
140155
algorithm = algorithm,
141156
broadcastThreshold = broadcastThreshold,
142-
checkpointInterval = checkpointInterval)
157+
checkpointInterval = checkpointInterval,
158+
intermediateStorageLevel = intermediateStorageLevel)
143159
}
144160
}
145161

@@ -256,7 +272,8 @@ object ConnectedComponents extends Logging {
256272
graph: GraphFrame,
257273
algorithm: String,
258274
broadcastThreshold: Int,
259-
checkpointInterval: Int): DataFrame = {
275+
checkpointInterval: Int,
276+
intermediateStorageLevel: StorageLevel): DataFrame = {
260277
require(supportedAlgorithms.contains(algorithm),
261278
s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.")
262279

@@ -290,7 +307,7 @@ object ConnectedComponents extends Logging {
290307
logger.info(s"$logPrefix Preparing the graph for connected component computation ...")
291308
val g = prepare(graph)
292309
val vv = g.vertices
293-
var ee = g.edges.persist(StorageLevel.MEMORY_AND_DISK) // src < dst
310+
var ee = g.edges.persist(intermediateStorageLevel) // src < dst
294311
val numEdges = ee.count()
295312
logger.info(s"$logPrefix Found $numEdges edges after preparation.")
296313

@@ -301,17 +318,17 @@ object ConnectedComponents extends Logging {
301318
// large-star step
302319
// compute min neighbors (including self-min)
303320
val minNbrs1 = minNbrs(ee) // src >= min_nbr
304-
.persist(StorageLevel.MEMORY_AND_DISK)
321+
.persist(intermediateStorageLevel)
305322
// connect all strictly larger neighbors to the min neighbor (including self)
306323
ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)
307324
.select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
308325
.distinct()
309-
.persist(StorageLevel.MEMORY_AND_DISK)
326+
.persist(intermediateStorageLevel)
310327

311328
// small-star step
312329
// compute min neighbors (excluding self-min)
313330
val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
314-
.persist(StorageLevel.MEMORY_AND_DISK)
331+
.persist(intermediateStorageLevel)
315332
// connect all smaller neighbors to the min neighbor
316333
ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
317334
.select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
@@ -337,7 +354,7 @@ object ConnectedComponents extends Logging {
337354
System.gc() // hint Spark to clean shuffle directories
338355
}
339356

340-
ee.persist(StorageLevel.MEMORY_AND_DISK)
357+
ee.persist(intermediateStorageLevel)
341358

342359
// test convergence
343360

src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import scala.reflect.runtime.universe.TypeTag
2525
import org.apache.spark.sql.{DataFrame, Row}
2626
import org.apache.spark.sql.functions.{col, lit}
2727
import org.apache.spark.sql.types.DataTypes
28+
import org.apache.spark.storage.StorageLevel
2829

2930
import org.graphframes._
3031
import org.graphframes.GraphFrame._
@@ -213,6 +214,22 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
213214
"The result shouldn't depend on checkpoint data if converged before first checkpoint.")
214215
}
215216

217+
test("intermediate storage level") {
218+
val friends = Graphs.friends
219+
val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g"))
220+
221+
val cc = friends.connectedComponents
222+
assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK)
223+
224+
for (storageLevel <- Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_ONLY, StorageLevel.NONE)) {
225+
// TODO: it is not trivial to confirm the actual storage level used
226+
val components = cc
227+
.setIntermediateStorageLevel(storageLevel)
228+
.run()
229+
assertComponents(components, expected)
230+
}
231+
}
232+
216233
private def assertComponents[T: ClassTag:TypeTag](
217234
actual: DataFrame,
218235
expected: Set[Set[T]]): Unit = {

0 commit comments

Comments
 (0)