Skip to content

Commit 85bfcc7

Browse files
zaleslawybabak
authored andcommitted
IGNITE-10803: [ML] Add prototype LogReg loading from PMML format
This closes apache#5744
1 parent e539dfc commit 85bfcc7

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed

examples/pom.xml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,31 @@
122122
<version>${javassist.version}</version>
123123
<scope>test</scope>
124124
</dependency>
125+
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-model -->
126+
<dependency>
127+
<groupId>org.jpmml</groupId>
128+
<artifactId>pmml-model</artifactId>
129+
<version>1.4.7</version>
130+
</dependency>
131+
132+
<dependency>
133+
<groupId>com.fasterxml.jackson.core</groupId>
134+
<artifactId>jackson-core</artifactId>
135+
<version>2.7.3</version>
136+
</dependency>
137+
138+
<dependency>
139+
<groupId>com.fasterxml.jackson.core</groupId>
140+
<artifactId>jackson-databind</artifactId>
141+
<version>2.7.3</version>
142+
</dependency>
143+
144+
<dependency>
145+
<groupId>com.fasterxml.jackson.core</groupId>
146+
<artifactId>jackson-annotations</artifactId>
147+
<version>2.7.3</version>
148+
</dependency>
149+
125150
</dependencies>
126151

127152
<properties>
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.ignite.examples.ml.inference;
19+
20+
import java.io.File;
21+
import java.io.FileInputStream;
22+
import java.io.FileNotFoundException;
23+
import java.io.IOException;
24+
import java.io.InputStream;
25+
import javax.xml.bind.JAXBException;
26+
import org.apache.ignite.Ignite;
27+
import org.apache.ignite.IgniteCache;
28+
import org.apache.ignite.Ignition;
29+
import org.apache.ignite.ml.math.primitives.vector.Vector;
30+
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
31+
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
32+
import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
33+
import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
34+
import org.apache.ignite.ml.util.MLSandboxDatasets;
35+
import org.apache.ignite.ml.util.SandboxMLCache;
36+
import org.dmg.pmml.PMML;
37+
import org.dmg.pmml.regression.RegressionModel;
38+
import org.dmg.pmml.regression.RegressionTable;
39+
import org.jpmml.model.PMMLUtil;
40+
import org.xml.sax.SAXException;
41+
42+
/**
43+
* Run logistic regression model loaded from PMML file. The PMML file was generated by Spark MLLib toPMML operator.
44+
* <p>
45+
* Code in this example launches Ignite grid and fills the cache with test data points (based on the
46+
* <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p>
47+
* <p>
48+
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
49+
*/
50+
public class LogRegFromSparkThroughPMMLExample {
51+
/** Run example. */
52+
public static void main(String[] args) throws FileNotFoundException {
53+
System.out.println();
54+
System.out.println(">>> Logistic regression model loaded from PMML over partitioned dataset usage example started.");
55+
// Start ignite grid.
56+
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
57+
System.out.println(">>> Ignite grid started.");
58+
59+
IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
60+
.fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
61+
62+
LogisticRegressionModel mdl = PMMLParser.load("examples/src/main/resources/models/spark/iris.pmml");
63+
64+
System.out.println(">>> Logistic regression model: " + mdl);
65+
66+
double accuracy = BinaryClassificationEvaluator.evaluate(
67+
dataCache,
68+
mdl,
69+
(k, v) -> v.copyOfRange(1, v.size()),
70+
(k, v) -> v.get(0),
71+
new Accuracy<>()
72+
);
73+
74+
System.out.println("\n>>> Accuracy " + accuracy);
75+
System.out.println("\n>>> Test Error " + (1 - accuracy));
76+
}
77+
}
78+
79+
/** Util class to build the LogReg model. */
80+
private static class PMMLParser {
81+
/**
82+
* @param path Path.
83+
*/
84+
public static LogisticRegressionModel load(String path) {
85+
try (InputStream is = new FileInputStream(new File(path))) {
86+
PMML pmml = PMMLUtil.unmarshal(is);
87+
88+
RegressionModel logRegMdl = (RegressionModel)pmml.getModels().get(0);
89+
90+
RegressionTable regTbl = logRegMdl.getRegressionTables().get(0);
91+
92+
Vector coefficients = new DenseVector(regTbl.getNumericPredictors().size());
93+
94+
for (int i = 0; i < regTbl.getNumericPredictors().size(); i++)
95+
coefficients.set(i, regTbl.getNumericPredictors().get(i).getCoefficient());
96+
97+
double interceptor = regTbl.getIntercept();
98+
99+
return new LogisticRegressionModel(coefficients, interceptor);
100+
}
101+
catch (IOException | JAXBException | SAXException e) {
102+
e.printStackTrace();
103+
}
104+
105+
return null;
106+
}
107+
}
108+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
2+
<PMML xmlns="http://www.dmg.org/PMML-4_2" version="4.2">
3+
<Header description="logistic regression">
4+
<Application name="Apache Spark MLlib" version="2.2.0"/>
5+
<Timestamp>2018-12-25T15:09:09</Timestamp>
6+
</Header>
7+
<DataDictionary numberOfFields="5">
8+
<DataField name="field_0" optype="continuous" dataType="double"/>
9+
<DataField name="field_1" optype="continuous" dataType="double"/>
10+
<DataField name="field_2" optype="continuous" dataType="double"/>
11+
<DataField name="field_3" optype="continuous" dataType="double"/>
12+
<DataField name="target" optype="categorical" dataType="string"/>
13+
</DataDictionary>
14+
<RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
15+
<MiningSchema>
16+
<MiningField name="field_0" usageType="active"/>
17+
<MiningField name="field_1" usageType="active"/>
18+
<MiningField name="field_2" usageType="active"/>
19+
<MiningField name="field_3" usageType="active"/>
20+
<MiningField name="target" usageType="target"/>
21+
</MiningSchema>
22+
<RegressionTable intercept="0.0" targetCategory="1">
23+
<NumericPredictor name="field_0" coefficient="5.84520630732407"/>
24+
<NumericPredictor name="field_1" coefficient="-19.36222130270906"/>
25+
<NumericPredictor name="field_2" coefficient="5.66074235971065"/>
26+
<NumericPredictor name="field_3" coefficient="16.110585062151788"/>
27+
</RegressionTable>
28+
<RegressionTable intercept="-0.0" targetCategory="0"/>
29+
</RegressionModel>
30+
</PMML>

0 commit comments

Comments
 (0)