Skip to content

Commit 5cdb55a

Browse files
committed
[+] compatibility_test: initial version + some test cases
1 parent 96b9882 commit 5cdb55a

File tree

4 files changed

+674
-0
lines changed

4 files changed

+674
-0
lines changed

testscripts/compatibility_cases.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
from string import Template
2+
3+
from compatibility_core import Case, LibraryType
4+
5+
6+
# LightGBM releases assigned to the LightGBM-based cases (`versions` attribute),
# listed newest first.
LIGHTGBM_VERSIONS = [
    '2.3.0',
    '2.2.3',
    '2.2.2',
    '2.2.1',
    '2.2.0',
    '2.1.2',
    '2.1.1',
    '2.1.0',
    '2.0.12',
    '2.0.11',
    '2.0.10',
]
19+
20+
# XGBoost releases assigned to the XGBoost-based cases (`versions` attribute),
# listed newest first.
XGBOOST_VERSIONS = [
    '0.90',
    '0.82',
    '0.72.1',
]
25+
26+
27+
class BaseCase(Case):
    """Shared behaviour for compatibility cases built from script templates.

    Subclasses provide ``python_template`` and ``go_template`` (objects with a
    ``substitute`` method, e.g. ``string.Template``).  The entries of ``files``
    are substituted into both templates, and ``compare`` checks the two
    prediction files against each other through ``compare_matrices``.
    """

    # Placeholder name -> file name; substituted into both script templates.
    files = dict(
        model_filename='model.txt',
        true_predictions_filename='true_predictions.txt',
        predictions_filename='predictions.txt',
        data_filename='data.txt',
    )
    # Script templates; concrete cases are expected to override both.
    python_template = None
    go_template = None
    # Settings forwarded to compare_matrices(); a subclass may override them
    # instead of re-implementing compare().
    compare_tolerance = 1e-10
    compare_max_mismatches_ratio = 0.0

    def compare(self):
        """Compare the reference predictions file with the produced one."""
        self.compare_matrices(
            matrix1_filename=self.files['true_predictions_filename'],
            matrix2_filename=self.files['predictions_filename'],
            tolerance=self.compare_tolerance,
            max_number_of_mismatches_ratio=self.compare_max_mismatches_ratio,
        )

    def go_code(self):
        """Return the Go program text with the file names filled in.

        Raises:
            NotImplementedError: if the class does not define ``go_template``.
        """
        return self._render_template('go_template')

    def python_code(self):
        """Return the Python script text with the file names filled in.

        Raises:
            NotImplementedError: if the class does not define
                ``python_template``.
        """
        return self._render_template('python_template')

    def _render_template(self, attribute_name):
        """Substitute ``files`` into the template stored under *attribute_name*."""
        template = getattr(self, attribute_name)
        if template is None:
            # Fail with an explicit message rather than an AttributeError on
            # NoneType when a subclass forgot to define the template.
            raise NotImplementedError(
                '{} does not define {}'.format(type(self).__name__, attribute_name)
            )
        return template.substitute(self.files)
50+
51+
class LGBaseCase(BaseCase):
    """Base for LightGBM cases: binds the library type and version list."""

    library = LibraryType.LIGHTGBM
    versions = LIGHTGBM_VERSIONS
54+
55+
56+
class XGBaseCase(BaseCase):
    """Base for XGBoost cases: binds the library type and version list."""

    library = LibraryType.XGBOOST
    versions = XGBOOST_VERSIONS
59+
60+
61+
class LGBreastCancer(LGBaseCase):
    """LightGBM gbdt binary model on the breast-cancer dataset.

    The Python script trains the model, stores raw-score predictions and dumps
    the test features as a tab-separated dense matrix.  The Go program reads
    that matrix, predicts with ``leaves`` and writes its own predictions file.
    """

    # Trains the model and writes the model, reference predictions and data.
    python_template = Template("""
import lightgbm as lgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

X, y = datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

n_estimators = 30
d_train = lgb.Dataset(X_train, label=y_train)
params = {
    'boosting_type': 'gbdt',
    'objective': 'binary',
}
clf = lgb.train(params, d_train, n_estimators)
y_pred = clf.predict(X_test, raw_score=True)

clf.save_model('$model_filename')  # save the model in txt format
np.savetxt('$true_predictions_filename', y_pred)
np.savetxt('$data_filename', X_test, delimiter='\t')
""")

    # Loads the model and the dense test matrix, writes leaves' predictions.
    go_template = Template("""
package main

import (
    "github.com/dmitryikh/leaves"
    "github.com/dmitryikh/leaves/mat"
)

func main() {
    test, err := mat.DenseMatFromCsvFile("$data_filename", 0, false, "\t", 0.0)
    if err != nil {
        panic(err)
    }

    model, err := leaves.LGEnsembleFromFile("$model_filename", false)
    if err != nil {
        panic(err)
    }
    predictions := mat.DenseMatZero(test.Rows, model.NOutputGroups())
    err = model.PredictDense(test.Values, test.Rows, test.Cols, predictions.Values, 0, 1)
    if err != nil {
        panic(err)
    }

    err = predictions.ToCsvFile("$predictions_filename", "\t")
    if err != nil {
        panic(err)
    }
}
""")
115+
116+
117+
class LGIrisRandomForest(LGBaseCase):
    """LightGBM random-forest binary model on the iris dataset.

    The Python script merges all iris classes above 0 into class 1, trains an
    ``rf`` booster, and writes the test data in svmlight format.  The Go
    program reads that file as a CSR matrix and predicts with ``leaves``.
    """

    # NOTE(review): the reference predictions are taken without raw_score while
    # the Go side passes `false` as the second argument when loading the model
    # - presumably equivalent for the averaged rf output; confirm.
    python_template = Template("""
import numpy as np
from sklearn import datasets
import lightgbm as lgb
from sklearn.model_selection import train_test_split


data = datasets.load_iris()
X = data['data']
y = data['target']
y[y > 0] = 1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

n_estimators = 30
d_train = lgb.Dataset(X_train, label=y_train)
params = {
    'boosting_type': 'rf',
    'objective': 'binary',
    'bagging_fraction': 0.8,
    'feature_fraction': 0.8,
    'bagging_freq': 1,
}

clf = lgb.train(params, d_train, n_estimators)

y_pred = clf.predict(X_test)

clf.save_model('$model_filename')
np.savetxt('$true_predictions_filename', y_pred)
datasets.dump_svmlight_file(X_test, y_test, '$data_filename')
""")

    # Loads the model and the svmlight test file, writes leaves' predictions.
    go_template = Template("""
package main

import (
    "github.com/dmitryikh/leaves"
    "github.com/dmitryikh/leaves/mat"
)

func main() {
    test, err := mat.CSRMatFromLibsvmFile("$data_filename", 0, true)
    if err != nil {
        panic(err)
    }

    model, err := leaves.LGEnsembleFromFile("$model_filename", false)
    if err != nil {
        panic(err)
    }

    predictions := mat.DenseMatZero(test.Rows(), model.NOutputGroups())
    err = model.PredictCSR(test.RowHeaders, test.ColIndexes, test.Values, predictions.Values, 0, 1)
    if err != nil {
        panic(err)
    }

    err = predictions.ToCsvFile("$predictions_filename", "\t")
    if err != nil {
        panic(err)
    }
}
""")
186+
187+
188+
class XGIrisMulticlass(XGBaseCase):
    """XGBoost three-class softmax model on the iris dataset.

    The Python script trains the booster, stores margin (``output_margin``)
    predictions tab-separated and writes the test data in svmlight format.
    The Go program reads that file as a CSR matrix and predicts with
    ``leaves``.
    """

    # Trains the model and writes the model, reference predictions and data.
    python_template = Template("""
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import xgboost as xgb

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

xg_train = xgb.DMatrix(X_train, label=y_train)
xg_test = xgb.DMatrix(X_test, label=y_test)
params = {
    'objective': 'multi:softmax',
    'num_class': 3,
}
n_estimators = 20
clf = xgb.train(params, xg_train, n_estimators)
y_pred = clf.predict(xg_test, output_margin=True)
# save the model in binary format
clf.save_model('$model_filename')
np.savetxt('$true_predictions_filename', y_pred, delimiter='\t')
datasets.dump_svmlight_file(X_test, y_test, '$data_filename')
""")

    # Loads the model and the svmlight test file, writes leaves' predictions.
    go_template = Template("""
package main

import (
    "github.com/dmitryikh/leaves"
    "github.com/dmitryikh/leaves/mat"
)

func main() {
    test, err := mat.CSRMatFromLibsvmFile("$data_filename", 0, true)
    if err != nil {
        panic(err)
    }

    model, err := leaves.XGEnsembleFromFile("$model_filename", false)
    if err != nil {
        panic(err)
    }

    predictions := mat.DenseMatZero(test.Rows(), model.NOutputGroups())
    err = model.PredictCSR(test.RowHeaders, test.ColIndexes, test.Values, predictions.Values, 0, 1)
    if err != nil {
        panic(err)
    }

    err = predictions.ToCsvFile("$predictions_filename", "\t")
    if err != nil {
        panic(err)
    }
}
""")

    def compare(self):
        """Compare the prediction files with a tolerance of 1e-6.

        NOTE(review): the looser tolerance is presumably due to XGBoost's
        lower numeric precision - confirm.
        """
        self.compare_matrices(
            matrix1_filename=self.files['true_predictions_filename'],
            matrix2_filename=self.files['predictions_filename'],
            tolerance=1e-6,
            max_number_of_mismatches_ratio=0.0,
        )
252+
253+
254+
# Case classes exported by this module.
cases = [
    LGBreastCancer,
    LGIrisRandomForest,
    XGIrisMulticlass,
]

0 commit comments

Comments
 (0)