Skip to content

Commit c16738d

Browse files
committed
[SYSTEMML-2218] Improved spark mapmm (avoid parallelize-repartition)
This patch improves the Spark mapmm instruction (broadcast-based matrix multiply) by avoiding an unnecessary shuffle for repartitioning - which is used to guarantee the output partition size - if the input is a parallelized RDD. For this scenario, we now create the parallelized RDD with the right number of partitions.
1 parent 094f555 commit c16738d

File tree

2 files changed

+33
-26
lines changed

2 files changed

+33
-26
lines changed

src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,13 @@ public static boolean isLocalMaster() {
291291
@SuppressWarnings("unchecked")
292292
public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockRDDHandleForVariable( String varname ) {
293293
return (JavaPairRDD<MatrixIndexes,MatrixBlock>)
294-
getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo);
294+
getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo, -1);
295+
}
296+
297+
@SuppressWarnings("unchecked")
298+
public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockRDDHandleForVariable( String varname, int numParts ) {
299+
return (JavaPairRDD<MatrixIndexes,MatrixBlock>)
300+
getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo, numParts);
295301
}
296302

297303
/**
@@ -304,15 +310,19 @@ public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockRDDHandleForVariable
304310
@SuppressWarnings("unchecked")
305311
public JavaPairRDD<Long,FrameBlock> getFrameBinaryBlockRDDHandleForVariable( String varname ) {
306312
JavaPairRDD<Long,FrameBlock> out = (JavaPairRDD<Long,FrameBlock>)
307-
getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo);
313+
getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo, -1);
308314
return out;
309315
}
310316

311317
public JavaPairRDD<?,?> getRDDHandleForVariable( String varname, InputInfo inputInfo ) {
318+
return getRDDHandleForVariable(varname, inputInfo, -1);
319+
}
320+
321+
public JavaPairRDD<?,?> getRDDHandleForVariable( String varname, InputInfo inputInfo, int numParts ) {
312322
Data dat = getVariable(varname);
313323
if( dat instanceof MatrixObject ) {
314324
MatrixObject mo = getMatrixObject(varname);
315-
return getRDDHandleForMatrixObject(mo, inputInfo);
325+
return getRDDHandleForMatrixObject(mo, inputInfo, numParts);
316326
}
317327
else if( dat instanceof FrameObject ) {
318328
FrameObject fo = getFrameObject(varname);
@@ -323,16 +333,12 @@ else if( dat instanceof FrameObject ) {
323333
}
324334
}
325335

326-
/**
327-
* This call returns an RDD handle for a given matrix object. This includes
328-
* the creation of RDDs for in-memory or binary-block HDFS data.
329-
*
330-
* @param mo matrix object
331-
* @param inputInfo input info
332-
* @return JavaPairRDD handle for a matrix object
333-
*/
334-
@SuppressWarnings("unchecked")
335336
public JavaPairRDD<?,?> getRDDHandleForMatrixObject( MatrixObject mo, InputInfo inputInfo ) {
337+
return getRDDHandleForMatrixObject(mo, inputInfo, -1);
338+
}
339+
340+
@SuppressWarnings("unchecked")
341+
public JavaPairRDD<?,?> getRDDHandleForMatrixObject( MatrixObject mo, InputInfo inputInfo, int numParts ) {
336342
//NOTE: MB this logic should be integrated into MatrixObject
337343
//However, for now we cannot assume that spark libraries are
338344
//always available and hence only store generic references in
@@ -366,7 +372,7 @@ else if( mo.isDirty() || mo.isCached(false) )
366372
}
367373
else { //default case
368374
MatrixBlock mb = mo.acquireRead(); //pin matrix in memory
369-
rdd = toMatrixJavaPairRDD(sc, mb, (int)mo.getNumRowsPerBlock(), (int)mo.getNumColumnsPerBlock());
375+
rdd = toMatrixJavaPairRDD(sc, mb, (int)mo.getNumRowsPerBlock(), (int)mo.getNumColumnsPerBlock(), numParts);
370376
mo.release(); //unpin matrix
371377
_parRDDs.registerRDD(rdd.id(), OptimizerUtils.estimatePartitionedSizeExactSparsity(mc), true);
372378
}
@@ -657,16 +663,11 @@ public void setRDDHandleForVariable(String varname, JavaPairRDD<?,?> rdd) {
657663
obj.setRDDHandle( rddhandle );
658664
}
659665

660-
/**
661-
* Utility method for creating an RDD out of an in-memory matrix block.
662-
*
663-
* @param sc java spark context
664-
* @param src matrix block
665-
* @param brlen block row length
666-
* @param bclen block column length
667-
* @return JavaPairRDD handle to matrix block
668-
*/
669666
public static JavaPairRDD<MatrixIndexes,MatrixBlock> toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int brlen, int bclen) {
667+
return toMatrixJavaPairRDD(sc, src, brlen, bclen, -1);
668+
}
669+
670+
public static JavaPairRDD<MatrixIndexes,MatrixBlock> toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int brlen, int bclen, int numParts) {
670671
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
671672
List<Tuple2<MatrixIndexes,MatrixBlock>> list = null;
672673

@@ -681,7 +682,9 @@ public static JavaPairRDD<MatrixIndexes,MatrixBlock> toMatrixJavaPairRDD(JavaSpa
681682
.collect(Collectors.toList());
682683
}
683684

684-
JavaPairRDD<MatrixIndexes,MatrixBlock> result = sc.parallelizePairs(list);
685+
JavaPairRDD<MatrixIndexes,MatrixBlock> result = (numParts > 1) ?
686+
sc.parallelizePairs(list, numParts) : sc.parallelizePairs(list);
687+
685688
if (DMLScript.STATISTICS) {
686689
Statistics.accSparkParallelizeTime(System.nanoTime() - t0);
687690
Statistics.incSparkParallelizeCount(1);

src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,11 @@ public void processInstruction(ExecutionContext ec) {
9797
MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
9898
MatrixCharacteristics mcBc = sec.getMatrixCharacteristics(bcastVar);
9999

100-
//get input rdd
101-
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
100+
//get input rdd with preferred number of partitions to avoid unnecessary repartition
101+
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar,
102+
(requiresFlatMapFunction(type, mcBc) && requiresRepartitioning(
103+
type, mcRdd, mcBc, sec.getSparkContext().defaultParallelism())) ?
104+
getNumRepartitioning(type, mcRdd, mcBc) : -1);
102105

103106
//investigate if a repartitioning - including a potential flip of broadcast and rdd
104107
//inputs - is required to ensure moderately sized output partitions (2GB limitation)
@@ -216,7 +219,8 @@ private static boolean requiresRepartitioning( CacheType type, MatrixCharacteris
216219
boolean isLargeOutput = (OptimizerUtils.estimatePartitionedSizeExactSparsity(isLeft?mcBc.getRows():mcRdd.getRows(),
217220
isLeft?mcRdd.getCols():mcBc.getCols(), isLeft?mcBc.getRowsPerBlock():mcRdd.getRowsPerBlock(),
218221
isLeft?mcRdd.getColsPerBlock():mcBc.getColsPerBlock(), 1.0) / numPartitions) > 1024*1024*1024;
219-
return isOuter && isLargeOutput && mcRdd.dimsKnown() && mcBc.dimsKnown();
222+
return isOuter && isLargeOutput && mcRdd.dimsKnown() && mcBc.dimsKnown()
223+
&& numPartitions < getNumRepartitioning(type, mcRdd, mcBc);
220224
}
221225

222226
/**

0 commit comments

Comments
 (0)