Skip to content

Commit d9de9c7

Browse files
committed
[SYSTEMML-995] Determine Frame Format If Needed
If `MLContextConversionUtil.dataFrameToFrameObject` receives a DataFrame *without* any frame metadata, a new `FrameMetadata` will be created with an empty `FrameFormat`, and so the subsequent `isDataFrameWithIDColumn` function will always return false ([line 360](https://github.com/apache/incubator-systemml/blob/master/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java#L360)). We should just create a new function similar to `determineMatrixFormatIfNeeded` for frames, and call it before the `isDataFrameWithIDColumn` function, as is done for DataFrame-matrix conversions ([line 412](https://github.com/apache/incubator-systemml/blob/master/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java#L412)).
1 parent 85a8746 commit d9de9c7

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ public static FrameObject dataFrameToFrameObject(String variableName, DataFrame
357357
//setup meta data and java spark context
358358
if (frameMetadata == null)
359359
frameMetadata = new FrameMetadata();
360+
determineFrameFormatIfNeeded(dataFrame, frameMetadata);
360361
boolean containsID = isDataFrameWithIDColumn(frameMetadata);
361362
JavaSparkContext javaSparkContext = MLContextUtil
362363
.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
@@ -490,6 +491,33 @@ public static void determineMatrixFormatIfNeeded(DataFrame dataFrame, MatrixMeta
490491
matrixMetadata.setMatrixFormat(mf);
491492
}
492493

494+
/**
495+
* If the FrameFormat of the DataFrame has not been explicitly specified,
496+
* attempt to determine the proper FrameFormat.
497+
*
498+
* @param dataFrame
499+
* the Spark {@code DataFrame}
500+
* @param frameMetadata
501+
* the frame metadata, if available
502+
*/
503+
public static void determineFrameFormatIfNeeded(DataFrame dataFrame, FrameMetadata frameMetadata) {
504+
FrameFormat frameFormat = frameMetadata.getFrameFormat();
505+
if (frameFormat != null) {
506+
return;
507+
}
508+
509+
StructType schema = dataFrame.schema();
510+
boolean hasID = false;
511+
try {
512+
schema.fieldIndex(RDDConverterUtils.DF_ID_COLUMN);
513+
hasID = true;
514+
} catch (IllegalArgumentException iae) {
515+
}
516+
517+
FrameFormat ff = hasID ? FrameFormat.DF_WITH_INDEX : FrameFormat.DF;
518+
frameMetadata.setFrameFormat(ff);
519+
}
520+
493521
/**
494522
* Return whether or not the DataFrame has an ID column.
495523
*

src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -271,24 +271,33 @@ private void testDataFrameScriptInput(ValueType[] schema, boolean containsID, bo
271271

272272
//create input data frame
273273
DataFrame df = createDataFrame(sqlctx, mbA, containsID, schema);
274-
FrameMetadata meta = new FrameMetadata(containsID ? FrameFormat.DF_WITH_INDEX :
274+
275+
// Create full frame metadata, and empty frame metadata
276+
FrameMetadata meta = new FrameMetadata(containsID ? FrameFormat.DF_WITH_INDEX :
275277
FrameFormat.DF, mc2.getRows(), mc2.getCols());
276-
278+
FrameMetadata metaEmpty = new FrameMetadata();
279+
277280
//create mlcontext
278281
ml = new MLContext(sc);
279282
ml.setExplain(true);
280283

281-
//run script and obtain result
284+
//run scripts and obtain result
282285
Script script1 = dml(
283-
//"Xf = read($Xffile); Xm = as.matrix(Xf); write(Xm, $Xmfile);")
284286
"Xm = as.matrix(Xf);")
285287
.in("Xf", df, meta).out("Xm");
286-
Matrix Xm = ml.execute(script1).getMatrix("Xm");
287-
MatrixBlock mbB = Xm.toBinaryBlockMatrix().getMatrixBlock();
288-
288+
Script script2 = dml(
289+
"Xm = as.matrix(Xf);")
290+
.in("Xf", df, metaEmpty).out("Xm"); // empty metadata
291+
Matrix Xm1 = ml.execute(script1).getMatrix("Xm");
292+
Matrix Xm2 = ml.execute(script2).getMatrix("Xm");
293+
MatrixBlock mbB1 = Xm1.toBinaryBlockMatrix().getMatrixBlock();
294+
MatrixBlock mbB2 = Xm2.toBinaryBlockMatrix().getMatrixBlock();
295+
289296
//compare frame blocks
290-
double[][] B = DataConverter.convertToDoubleMatrix(mbB);
291-
TestUtils.compareMatrices(A, B, rows1, cols, eps);
297+
double[][] B1 = DataConverter.convertToDoubleMatrix(mbB1);
298+
double[][] B2 = DataConverter.convertToDoubleMatrix(mbB2);
299+
TestUtils.compareMatrices(A, B1, rows1, cols, eps);
300+
TestUtils.compareMatrices(A, B2, rows1, cols, eps);
292301
}
293302
catch( Exception ex ) {
294303
ex.printStackTrace();
@@ -364,4 +373,4 @@ private DataFrame createDataFrame(SQLContext sqlctx, MatrixBlock mb, boolean con
364373
JavaRDD<Row> rowRDD = sc.parallelize(list);
365374
return sqlctx.createDataFrame(rowRDD, dfSchema);
366375
}
367-
}
376+
}

0 commit comments

Comments
 (0)