Skip to content

Commit 4740506

Browse files
committed
[SYSTEMML-1656] Fix BLAS integration (corrupted matrix block apis)
The dispatching between operations over uncompressed or compressed matrix blocks is realized via late binding. The recently added BLAS integration introduced additional matrix block APIs without overriding them for compressed matrix blocks. This corrupted, for example, matrix-vector operations over compressed matrices as they are mistakenly routed to uncompressed operations. This patch fixes this issue by removing these unnecessary API extensions and simplifying the CP aggregate binary instruction to avoid the impression that all compressed matrices are handled through the vector-matrix branch.
1 parent ae71d00 commit 4740506

File tree

5 files changed

+275
-40
lines changed

5 files changed

+275
-40
lines changed

src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateBinaryCPInstruction.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
3232
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
3333
import org.apache.sysml.runtime.matrix.operators.Operator;
34-
import org.apache.sysml.utils.NativeHelper;
3534

3635
public class AggregateBinaryCPInstruction extends BinaryCPInstruction
3736
{
@@ -72,20 +71,16 @@ public void processInstruction(ExecutionContext ec)
7271
{
7372
//get inputs
7473
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
75-
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
74+
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
75+
76+
//compute matrix multiplication
77+
AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
78+
MatrixBlock main = (matBlock2 instanceof CompressedMatrixBlock) ? matBlock2 : matBlock1;
79+
MatrixBlock ret = (MatrixBlock) main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
7680

77-
//compute matrix multiplication
78-
AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
79-
MatrixBlock soresBlock = null;
80-
if( matBlock2 instanceof CompressedMatrixBlock )
81-
soresBlock = (MatrixBlock) (matBlock2.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op));
82-
else {
83-
soresBlock = (MatrixBlock) (matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, NativeHelper.isNativeLibraryLoaded()));
84-
}
85-
8681
//release inputs/outputs
8782
ec.releaseMatrixInput(input1.getName());
8883
ec.releaseMatrixInput(input2.getName());
89-
ec.setMatrixOutput(output.getName(), soresBlock);
84+
ec.setMatrixOutput(output.getName(), ret);
9085
}
9186
}

src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4894,20 +4894,10 @@ private double sumWeightForQuantile()
48944894
public MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value,
48954895
MatrixValue result, AggregateBinaryOperator op ) throws DMLRuntimeException
48964896
{
4897-
return aggregateBinaryOperations(m1Value, m2Value, result, op, NativeHelper.isNativeLibraryLoaded());
4897+
return aggregateBinaryOperations(m1Value, m2Value, result, op);
48984898
}
48994899

4900-
public MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value,
4901-
MatrixValue result, AggregateBinaryOperator op, boolean enableNativeBLAS ) throws DMLRuntimeException
4902-
{
4903-
return aggregateBinaryOperations(m1Value, m2Value, result, op, enableNativeBLAS);
4904-
}
4905-
4906-
public MatrixValue aggregateBinaryOperations(MatrixValue m1Value, MatrixValue m2Value, MatrixValue result, AggregateBinaryOperator op) throws DMLRuntimeException {
4907-
return aggregateBinaryOperations(m1Value, m2Value, result, op, NativeHelper.isNativeLibraryLoaded());
4908-
}
4909-
4910-
public MatrixValue aggregateBinaryOperations(MatrixValue m1Value, MatrixValue m2Value, MatrixValue result, AggregateBinaryOperator op, boolean nativeMatMult)
4900+
public MatrixValue aggregateBinaryOperations(MatrixValue m1Value, MatrixValue m2Value, MatrixValue result, AggregateBinaryOperator op)
49114901
throws DMLRuntimeException
49124902
{
49134903
//check input types, dimensions, configuration
@@ -4933,7 +4923,7 @@ public MatrixValue aggregateBinaryOperations(MatrixValue m1Value, MatrixValue m2
49334923
ret.reset(rl, cl, sp.sparse, sp.estimatedNonZeros);
49344924

49354925
//compute matrix multiplication (only supported binary aggregate operation)
4936-
if( nativeMatMult )
4926+
if( NativeHelper.isNativeLibraryLoaded() )
49374927
LibMatrixNative.matrixMult(m1, m2, ret, op.getNumThreads());
49384928
else if( op.getNumThreads() > 1 )
49394929
LibMatrixMult.matrixMult(m1, m2, ret, op.getNumThreads());
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysml.test.integration.functions.compress;
21+
22+
import java.io.File;
23+
import java.util.HashMap;
24+
25+
import org.apache.sysml.api.DMLScript;
26+
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
27+
import org.apache.sysml.lops.LopProperties.ExecType;
28+
import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
29+
import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
30+
import org.apache.sysml.test.integration.AutomatedTestBase;
31+
import org.apache.sysml.test.integration.TestConfiguration;
32+
import org.apache.sysml.test.utils.TestUtils;
33+
import org.junit.Test;
34+
35+
/**
36+
*
37+
*/
38+
public class CompressedL2SVM extends AutomatedTestBase
39+
{
40+
private final static String TEST_NAME1 = "L2SVM";
41+
private final static String TEST_DIR = "functions/compress/";
42+
private final static String TEST_CONF = "SystemML-config-compress.xml";
43+
private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF);
44+
45+
private final static double eps = 1e-4;
46+
47+
private final static int rows = 1468;
48+
private final static int cols = 980;
49+
50+
private final static double sparsity1 = 0.7; //dense
51+
private final static double sparsity2 = 0.1; //sparse
52+
53+
private final static int intercept = 0;
54+
private final static double epsilon = 0.000000001;
55+
private final static double maxiter = 10;
56+
57+
@Override
58+
public void setUp() {
59+
TestUtils.clearAssertionInformation();
60+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1, new String[] { "w" }));
61+
}
62+
63+
@Test
64+
public void testL2SVMDenseCP() {
65+
runL2SVMTest(TEST_NAME1, false, ExecType.CP);
66+
}
67+
68+
@Test
69+
public void testL2SVMSparseCP() {
70+
runL2SVMTest(TEST_NAME1, true, ExecType.CP);
71+
}
72+
73+
@Test
74+
public void testL2SVMDenseSP() {
75+
runL2SVMTest(TEST_NAME1, false, ExecType.SPARK);
76+
}
77+
78+
@Test
79+
public void testL2SVMSparseSP() {
80+
runL2SVMTest(TEST_NAME1, true, ExecType.SPARK);
81+
}
82+
83+
/**
84+
*
85+
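* @param testname name of the test configuration to load
86+
* @param sparse if true, X is generated with sparsity2 (0.1); otherwise with sparsity1 (0.7)
87+
* @param instType execution type used to select the runtime platform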
88+
*/
89+
private void runL2SVMTest( String testname,boolean sparse, ExecType instType)
90+
{
91+
//rtplatform for MR
92+
RUNTIME_PLATFORM platformOld = rtplatform;
93+
switch( instType ){
94+
case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
95+
case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
96+
default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
97+
}
98+
99+
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
100+
if( rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK || rtplatform == RUNTIME_PLATFORM.SPARK )
101+
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
102+
103+
try
104+
{
105+
String TEST_NAME = testname;
106+
TestConfiguration config = getTestConfiguration(TEST_NAME);
107+
loadTestConfiguration(config);
108+
109+
fullDMLScriptName = "scripts/algorithms/l2-svm.dml";
110+
programArgs = new String[]{ "-explain", "-stats", "-nvargs", "X="+input("X"), "Y="+input("Y"),
111+
"icpt="+String.valueOf(intercept), "tol="+String.valueOf(epsilon), "reg=0.001",
112+
"maxiter="+String.valueOf(maxiter), "model="+output("w"), "Log= "};
113+
114+
rCmd = getRCmd(inputDir(), String.valueOf(intercept),String.valueOf(epsilon),
115+
String.valueOf(maxiter), expectedDir());
116+
117+
//generate actual datasets
118+
double[][] X = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 714);
119+
writeInputMatrixWithMTD("X", X, true);
120+
double[][] y = TestUtils.round(getRandomMatrix(rows, 1, 0, 1, 1.0, 136));
121+
writeInputMatrixWithMTD("Y", y, true);
122+
123+
runTest(true, false, null, -1);
124+
runRScript(true);
125+
126+
//compare matrices
127+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("w");
128+
HashMap<CellIndex, Double> rfile = readRMatrixFromFS("w");
129+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
130+
}
131+
finally {
132+
rtplatform = platformOld;
133+
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
134+
CompressedMatrixBlock.ALLOW_DDC_ENCODING = true;
135+
}
136+
}
137+
138+
/**
139+
* Override default configuration with custom test configuration to ensure
140+
* scratch space and local temporary directory locations are also updated.
141+
*/
142+
@Override
143+
protected File getConfigTemplateFile() {
144+
// Instrumentation in this test's output log to show custom configuration file used for template.
145+
System.out.println("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
146+
return TEST_CONF_FILE;
147+
}
148+
}

src/test/java/org/apache/sysml/test/integration/functions/compress/CompressedLinregCG.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,32 +62,26 @@ public void setUp() {
6262
}
6363

6464
@Test
65-
public void testGDFOLinregCGDenseCP() {
66-
runGDFOTest(TEST_NAME1, false, ExecType.CP);
65+
public void testLinregCGDenseCP() {
66+
runLinregCGTest(TEST_NAME1, false, ExecType.CP);
6767
}
6868

6969
@Test
70-
public void testGDFOLinregCGSparseCP() {
71-
runGDFOTest(TEST_NAME1, true, ExecType.CP);
70+
public void testLinregCGSparseCP() {
71+
runLinregCGTest(TEST_NAME1, true, ExecType.CP);
7272
}
7373

7474
@Test
75-
public void testGDFOLinregCGDenseSP() {
76-
runGDFOTest(TEST_NAME1, false, ExecType.SPARK);
75+
public void testLinregCGDenseSP() {
76+
runLinregCGTest(TEST_NAME1, false, ExecType.SPARK);
7777
}
7878

7979
@Test
80-
public void testGDFOLinregCGSparseSP() {
81-
runGDFOTest(TEST_NAME1, true, ExecType.SPARK);
80+
public void testLinregCGSparseSP() {
81+
runLinregCGTest(TEST_NAME1, true, ExecType.SPARK);
8282
}
8383

84-
/**
85-
*
86-
* @param sparseM1
87-
* @param sparseM2
88-
* @param instType
89-
*/
90-
private void runGDFOTest( String testname,boolean sparse, ExecType instType)
84+
private void runLinregCGTest( String testname,boolean sparse, ExecType instType)
9185
{
9286
//rtplatform for MR
9387
RUNTIME_PLATFORM platformOld = rtplatform;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args <- commandArgs(TRUE)
23+
library("Matrix")
24+
25+
X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")));
26+
Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")));
27+
intercept = as.integer(args[2]);
28+
epsilon = as.double(args[3]);
29+
lambda = 0.001;
30+
maxiterations = as.integer(args[4]);
31+
32+
check_min = min(Y)
33+
check_max = max(Y)
34+
num_min = sum(Y == check_min)
35+
num_max = sum(Y == check_max)
36+
if(num_min + num_max != nrow(Y)){
37+
print("please check Y, it should contain only 2 labels")
38+
}else{
39+
if(check_min != -1 | check_max != +1)
40+
Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - check_min)
41+
}
42+
43+
dimensions = ncol(X)
44+
45+
if (intercept == 1) {
46+
ones = matrix(1, rows=num_samples, cols=1)
47+
X = cbind(X, ones);
48+
}
49+
50+
num_rows_in_w = dimensions
51+
if(intercept == 1){
52+
num_rows_in_w = num_rows_in_w + 1
53+
}
54+
w = matrix(0, num_rows_in_w, 1)
55+
56+
g_old = t(X) %*% Y
57+
s = g_old
58+
59+
Xw = matrix(0,nrow(X),1)
60+
iter = 0
61+
positive_label = check_max
62+
negative_label = check_min
63+
64+
continue = TRUE
65+
while(continue && iter < maxiterations){
66+
t = 0
67+
Xd = X %*% s
68+
wd = lambda * sum(w * s)
69+
dd = lambda * sum(s * s)
70+
continue1 = TRUE
71+
while(continue1){
72+
tmp_Xw = Xw + t*Xd
73+
out = 1 - Y * (tmp_Xw)
74+
sv = which(out > 0)
75+
g = wd + t*dd - sum(out[sv] * Y[sv] * Xd[sv])
76+
h = dd + sum(Xd[sv] * Xd[sv])
77+
t = t - g/h
78+
continue1 = (g*g/h >= 1e-10)
79+
}
80+
81+
w = w + t*s
82+
Xw = Xw + t*Xd
83+
84+
out = 1 - Y * (X %*% w)
85+
sv = which(out > 0)
86+
obj = 0.5 * sum(out[sv] * out[sv]) + lambda/2 * sum(w * w)
87+
g_new = t(X[sv,]) %*% (out[sv] * Y[sv]) - lambda * w
88+
89+
print(paste("OBJ : ", obj))
90+
91+
continue = (t*sum(s * g_old) >= epsilon*obj)
92+
93+
be = sum(g_new * g_new)/sum(g_old * g_old)
94+
s = be * s + g_new
95+
g_old = g_new
96+
97+
iter = iter + 1
98+
}
99+
100+
extra_model_params = matrix(0, 4, 1)
101+
extra_model_params[1,1] = positive_label
102+
extra_model_params[2,1] = negative_label
103+
extra_model_params[3,1] = intercept
104+
extra_model_params[4,1] = dimensions
105+
106+
w = t(cbind(t(w), t(extra_model_params)))
107+
108+
writeMM(as(w,"CsparseMatrix"), paste(args[5], "w", sep=""));

0 commit comments

Comments
 (0)