Skip to content

Commit ec0eaed

Browse files
committed
add lightgbm regressor.
1 parent 3048dba commit ec0eaed

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

sklearn_model_selection_lightgbm.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import lightgbm as lgb
4+
5+
from sklearn import model_selection
6+
7+
from sklearn import ensemble
8+
from sklearn import datasets
9+
10+
from sklearn.utils import shuffle
11+
12+
X = list('aabcccdddd')
13+
print(X)
14+
k_fold = model_selection.KFold(n_splits=10)
15+
for train_indices, test_indices in k_fold.split(X):
16+
print('Train: %s | Test: %s' % (
17+
train_indices, test_indices))
18+
19+
boston = datasets.load_boston()
20+
X, y = shuffle(boston.data, boston.target, random_state=13)
21+
X = X.astype(np.float32)
22+
y = y.astype(np.float32)
23+
24+
params = {
25+
'n_estimators': 200,
26+
'max_depth':4,
27+
'min_samples_split': 2,
28+
'learning_rate':0.01,
29+
'loss': 'ls'
30+
}
31+
32+
models = []
33+
for i in [100, 200, 300, 400, 500, 600, 700]:
34+
params.update(n_estimators=i)
35+
models.append([str(i), ensemble.GradientBoostingRegressor(**params)])
36+
for i in [100, 200, 300, 400, 500, 600, 700]:
37+
models.append(['g' + str(i), lgb.LGBMRegressor(
38+
objective='regression', num_leaves=31,learning_rate=0.05,n_estimators=i)])
39+
40+
seed = 13
41+
scoring = 'neg_mean_squared_log_error'
42+
results = []
43+
names = []
44+
for name, model in models:
45+
kfold = model_selection.KFold(n_splits=10, random_state=seed)
46+
cv_results = model_selection.cross_val_score(
47+
model, X, y, cv=kfold, scoring=scoring)
48+
results.append(cv_results)
49+
names.append(name)
50+
msg = "%s: %f (%f), %f %f" % (name,
51+
cv_results.mean(), cv_results.std(),
52+
cv_results.min(), cv_results.max())
53+
print(msg)
54+
55+
fig = plt.figure()
56+
fig.suptitle('Algorithm Comparison')
57+
ax = fig.add_subplot(111)
58+
plt.boxplot(results)
59+
ax.set_xticklabels(names)
60+
plt.show()

0 commit comments

Comments
 (0)