Skip to content

Commit b515edd

Browse files
committed
[*] NClasses -> NRawOutputGroups
1 parent c602564 commit b515edd

11 files changed

+87
-85
lines changed

doc.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ predict_breast_cancer_model.go:
6161
}
6262
fmt.Printf("Name: %s\n", model.Name())
6363
fmt.Printf("NFeatures: %d\n", model.NFeatures())
64-
fmt.Printf("NClasses: %d\n", model.NClasses())
64+
fmt.Printf("NRawOutputGroups: %d\n", model.NRawOutputGroups())
6565
fmt.Printf("NEstimators: %d\n", model.NEstimators())
6666
6767
// loading true predictions as DenseMat
@@ -71,7 +71,7 @@ predict_breast_cancer_model.go:
7171
}
7272
7373
// preallocate slice to store model predictions
74-
predictions := make([]float64, test.Rows*model.NClasses())
74+
predictions := make([]float64, test.Rows*model.NRawOutputGroups())
7575
// do predictions
7676
model.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, 1)
7777
// compare results
@@ -86,7 +86,7 @@ Output:
8686
8787
Name: lightgbm.gbdt
8888
NFeatures: 30
89-
NClasses: 1
89+
NRawOutputGroups: 1
9090
NEstimators: 30
9191
Predictions the same!
9292
@@ -146,7 +146,7 @@ predict_iris_model.go:
146146
}
147147
fmt.Printf("Name: %s\n", model.Name())
148148
fmt.Printf("NFeatures: %d\n", model.NFeatures())
149-
fmt.Printf("NClasses: %d\n", model.NClasses())
149+
fmt.Printf("NRawOutputGroups: %d\n", model.NRawOutputGroups())
150150
fmt.Printf("NEstimators: %d\n", model.NEstimators())
151151
152152
// loading true predictions as DenseMat
@@ -156,7 +156,7 @@ predict_iris_model.go:
156156
}
157157
158158
// preallocate slice to store model predictions
159-
predictions := make([]float64, csr.Rows()*model.NClasses())
159+
predictions := make([]float64, csr.Rows()*model.NRawOutputGroups())
160160
// do predictions
161161
model.PredictCSR(csr.RowHeaders, csr.ColIndexes, csr.Values, predictions, 0, 1)
162162
// compare results
@@ -177,7 +177,7 @@ Output:
177177
178178
Name: xgboost.gbtree
179179
NFeatures: 4
180-
NClasses: 3
180+
NRawOutputGroups: 3
181181
NEstimators: 5
182182
Predictions the same! (mismatch = 0)
183183

leaves.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const BatchSize = 16
1212

1313
type ensembleBaseInterface interface {
1414
NEstimators() int
15-
NClasses() int
15+
NRawOutputGroups() int
1616
NFeatures() int
1717
Name() string
1818
adjustNEstimators(nEstimators int) int
@@ -32,7 +32,7 @@ type Ensemble struct {
3232
// function transformation and etc)
3333
// NOTE: for multiclass prediction use Predict
3434
func (e *Ensemble) PredictSingle(fvals []float64, nEstimators int) float64 {
35-
if e.NClasses() != 1 {
35+
if e.NRawOutputGroups() != 1 {
3636
return 0.0
3737
}
3838
if e.NFeatures() > len(fvals) {
@@ -51,8 +51,8 @@ func (e *Ensemble) PredictSingle(fvals []float64, nEstimators int) float64 {
5151
// NOTE: for single class predictions one can use simplified function PredictSingle
5252
func (e *Ensemble) Predict(fvals []float64, nEstimators int, predictions []float64) error {
5353
nRows := 1
54-
if len(predictions) < e.NClasses()*nRows {
55-
return fmt.Errorf("predictions slice too short (should be at least %d)", e.NClasses()*nRows)
54+
if len(predictions) < e.NRawOutputGroups()*nRows {
55+
return fmt.Errorf("predictions slice too short (should be at least %d)", e.NRawOutputGroups()*nRows)
5656
}
5757
if e.NFeatures() > len(fvals) {
5858
return fmt.Errorf("incorrect number of features (%d)", len(fvals))
@@ -72,8 +72,8 @@ func (e *Ensemble) Predict(fvals []float64, nEstimators int, predictions []float
7272
// Note: the `predictions` slice must be properly allocated by the caller
7373
func (e *Ensemble) PredictCSR(indptr []int, cols []int, vals []float64, predictions []float64, nEstimators int, nThreads int) error {
7474
nRows := len(indptr) - 1
75-
if len(predictions) < e.NClasses()*nRows {
76-
return fmt.Errorf("predictions slice too short (should be at least %d)", e.NClasses()*nRows)
75+
if len(predictions) < e.NRawOutputGroups()*nRows {
76+
return fmt.Errorf("predictions slice too short (should be at least %d)", e.NRawOutputGroups()*nRows)
7777
}
7878
nEstimators = e.adjustNEstimators(nEstimators)
7979
if nRows <= BatchSize || nThreads == 0 || nThreads == 1 {
@@ -136,7 +136,7 @@ func (e *Ensemble) predictCSRInner(
136136
fvals[cols[j]] = vals[j]
137137
}
138138
}
139-
e.predictInner(fvals, nEstimators, predictions, i*e.NClasses())
139+
e.predictInner(fvals, nEstimators, predictions, i*e.NRawOutputGroups())
140140
e.resetFVals(fvals)
141141
}
142142
}
@@ -156,8 +156,8 @@ func (e *Ensemble) PredictDense(
156156
nThreads int,
157157
) error {
158158
nRows := nrows
159-
if len(predictions) < e.NClasses()*nRows {
160-
return fmt.Errorf("predictions slice too short (should be at least %d)", e.NClasses()*nRows)
159+
if len(predictions) < e.NRawOutputGroups()*nRows {
160+
return fmt.Errorf("predictions slice too short (should be at least %d)", e.NRawOutputGroups()*nRows)
161161
}
162162
if ncols == 0 || e.NFeatures() > ncols {
163163
return fmt.Errorf("incorrect number of columns")
@@ -166,7 +166,7 @@ func (e *Ensemble) PredictDense(
166166
if nRows <= BatchSize || nThreads == 0 || nThreads == 1 {
167167
// single thread calculations
168168
for i := 0; i < nRows; i++ {
169-
e.predictInner(vals[i*ncols:(i+1)*ncols], nEstimators, predictions, i*e.NClasses())
169+
e.predictInner(vals[i*ncols:(i+1)*ncols], nEstimators, predictions, i*e.NRawOutputGroups())
170170
}
171171
return nil
172172
}
@@ -190,7 +190,7 @@ func (e *Ensemble) PredictDense(
190190
endIndex = nRows
191191
}
192192
for i := startIndex; i < endIndex; i++ {
193-
e.predictInner(vals[i*int(ncols):(i+1)*int(ncols)], nEstimators, predictions, i*e.NClasses())
193+
e.predictInner(vals[i*int(ncols):(i+1)*int(ncols)], nEstimators, predictions, i*e.NRawOutputGroups())
194194
}
195195
}
196196
}()
@@ -205,14 +205,16 @@ func (e *Ensemble) PredictDense(
205205
return nil
206206
}
207207

208-
// NEstimators returns number of estimators (trees) in ensemble (per class)
208+
// NEstimators returns the number of estimators (trees) in the ensemble (per output group)
209209
func (e *Ensemble) NEstimators() int {
210210
return e.ensembleBaseInterface.NEstimators()
211211
}
212212

213-
// NClasses returns number of classes to predict
214-
func (e *Ensemble) NClasses() int {
215-
return e.ensembleBaseInterface.NClasses()
213+
// NRawOutputGroups returns the number of raw output values (groups) in the
214+
// prediction for every object. For example, a binary logistic model gives 1
215+
// value per object, but a 4-class prediction model gives 4 values per object
216+
func (e *Ensemble) NRawOutputGroups() int {
217+
return e.ensembleBaseInterface.NRawOutputGroups()
216218
}
217219

218220
// NFeatures returns number of features in the model

leaves_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ func InnerTestLGMulticlass(t *testing.T, nThreads int) {
459459
if model.NEstimators() != 10 {
460460
t.Fatalf("expected 10 trees (got %d)", model.NEstimators())
461461
}
462-
if model.NClasses() != 5 {
463-
t.Fatalf("expected 5 classes (got %d)", model.NClasses())
462+
if model.NRawOutputGroups() != 5 {
463+
t.Fatalf("expected 5 classes (got %d)", model.NRawOutputGroups())
464464
}
465465

466466
// loading true predictions as DenseMat
@@ -470,7 +470,7 @@ func InnerTestLGMulticlass(t *testing.T, nThreads int) {
470470
}
471471

472472
// do predictions
473-
predictions := make([]float64, dense.Rows*model.NClasses())
473+
predictions := make([]float64, dense.Rows*model.NRawOutputGroups())
474474
model.PredictDense(dense.Values, dense.Rows, dense.Cols, predictions, 0, nThreads)
475475
// compare results
476476
const tolerance = 1e-7
@@ -481,12 +481,12 @@ func InnerTestLGMulticlass(t *testing.T, nThreads int) {
481481
// check Predict
482482
singleIdx := 200
483483
fvals := dense.Values[singleIdx*dense.Cols : (singleIdx+1)*dense.Cols]
484-
predictions = make([]float64, model.NClasses())
484+
predictions = make([]float64, model.NRawOutputGroups())
485485
err = model.Predict(fvals, 0, predictions)
486486
if err != nil {
487487
t.Errorf("error while call model.Predict: %s", err.Error())
488488
}
489-
if err := util.AlmostEqualFloat64Slices(truePredictions.Values[singleIdx*model.NClasses():(singleIdx+1)*model.NClasses()], predictions, tolerance); err != nil {
489+
if err := util.AlmostEqualFloat64Slices(truePredictions.Values[singleIdx*model.NRawOutputGroups():(singleIdx+1)*model.NRawOutputGroups()], predictions, tolerance); err != nil {
490490
t.Errorf("different Predict prediction: %s", err.Error())
491491
}
492492
}
@@ -531,7 +531,7 @@ func InnerTestXGDermatology(t *testing.T, nThreads int) {
531531
}
532532

533533
// do predictions
534-
predictions := make([]float64, csr.Rows()*model.NClasses())
534+
predictions := make([]float64, csr.Rows()*model.NRawOutputGroups())
535535
model.PredictCSR(csr.RowHeaders, csr.ColIndexes, csr.Values, predictions, 0, nThreads)
536536
// compare results
537537
const tolerance = 1e-6
@@ -563,7 +563,7 @@ func TestSKGradientBoostingClassifier(t *testing.T) {
563563
}
564564

565565
// do predictions
566-
predictions := make([]float64, csr.Rows()*model.NClasses())
566+
predictions := make([]float64, csr.Rows()*model.NRawOutputGroups())
567567
model.PredictCSR(csr.RowHeaders, csr.ColIndexes, csr.Values, predictions, 0, 1)
568568
// compare results
569569
const tolerance = 1e-6
@@ -597,7 +597,7 @@ func TestSKIris(t *testing.T) {
597597
}
598598

599599
// do predictions
600-
predictions := make([]float64, csr.Rows()*model.NClasses())
600+
predictions := make([]float64, csr.Rows()*model.NRawOutputGroups())
601601
model.PredictCSR(csr.RowHeaders, csr.ColIndexes, csr.Values, predictions, 0, 1)
602602
// compare results
603603
const tolerance = 1e-6
@@ -639,7 +639,7 @@ func TestLGRandomForestIris(t *testing.T) {
639639
}
640640

641641
// do predictions
642-
predictions := make([]float64, csr.Rows()*model.NClasses())
642+
predictions := make([]float64, csr.Rows()*model.NRawOutputGroups())
643643
model.PredictCSR(csr.RowHeaders, csr.ColIndexes, csr.Values, predictions, 0, 1)
644644
// compare results
645645
const tolerance = 1e-6
@@ -714,7 +714,7 @@ func TestLGDARTBreastCancer(t *testing.T) {
714714
}
715715

716716
// do predictions
717-
predictions := make([]float64, test.Rows*model.NClasses())
717+
predictions := make([]float64, test.Rows*model.NRawOutputGroups())
718718
err = model.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, 1)
719719
if err != nil {
720720
t.Fatal(err)
@@ -753,7 +753,7 @@ func TestLGKDDCup99(t *testing.T) {
753753
}
754754

755755
// do predictions
756-
predictions := make([]float64, test.Rows*model.NClasses())
756+
predictions := make([]float64, test.Rows*model.NRawOutputGroups())
757757
err = model.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, 1)
758758
if err != nil {
759759
t.Fatal(err)
@@ -790,7 +790,7 @@ func InnerBenchmarkLGKDDCup99(b *testing.B, nThreads int) {
790790

791791
// do benchmark
792792
b.ResetTimer()
793-
predictions := make([]float64, test.Rows*model.NClasses())
793+
predictions := make([]float64, test.Rows*model.NRawOutputGroups())
794794
for i := 0; i < b.N; i++ {
795795
model.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, nThreads)
796796
}
@@ -826,7 +826,7 @@ func TestLGJsonBreastCancer(t *testing.T) {
826826
}
827827

828828
// do predictions
829-
predictions := make([]float64, test.Rows*model.NClasses())
829+
predictions := make([]float64, test.Rows*model.NRawOutputGroups())
830830
err = model.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, 1)
831831
if err != nil {
832832
t.Fatal(err)

lgensemble.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import (
66

77
// lgEnsemble is LightGBM model (ensemble of trees)
88
type lgEnsemble struct {
9-
Trees []lgTree
10-
MaxFeatureIdx int
11-
nClasses int
9+
Trees []lgTree
10+
MaxFeatureIdx int
11+
nRawOutputGroups int
1212
// lgEnsemble is suitable for models from different packages (e.g., LightGBM gbrt & sklearn gbrt)
1313
// name contains the origin of the model
1414
name string
@@ -19,11 +19,11 @@ type lgEnsemble struct {
1919
}
2020

2121
func (e *lgEnsemble) NEstimators() int {
22-
return len(e.Trees) / e.nClasses
22+
return len(e.Trees) / e.nRawOutputGroups
2323
}
2424

25-
func (e *lgEnsemble) NClasses() int {
26-
return e.nClasses
25+
func (e *lgEnsemble) NRawOutputGroups() int {
26+
return e.nRawOutputGroups
2727
}
2828

2929
func (e *lgEnsemble) NFeatures() int {
@@ -38,16 +38,16 @@ func (e *lgEnsemble) Name() string {
3838
}
3939

4040
func (e *lgEnsemble) predictInner(fvals []float64, nEstimators int, predictions []float64, startIndex int) {
41-
for k := 0; k < e.nClasses; k++ {
41+
for k := 0; k < e.nRawOutputGroups; k++ {
4242
predictions[startIndex+k] = 0.0
4343
}
4444
coef := 1.0
4545
if e.averageOutput {
4646
coef = 1.0 / float64(nEstimators)
4747
}
4848
for i := 0; i < nEstimators; i++ {
49-
for k := 0; k < e.nClasses; k++ {
50-
predictions[startIndex+k] += e.Trees[i*e.nClasses+k].predict(fvals) * coef
49+
for k := 0; k < e.nRawOutputGroups; k++ {
50+
predictions[startIndex+k] += e.Trees[i*e.nRawOutputGroups+k].predict(fvals) * coef
5151
}
5252
}
5353
}

lgensemble_io.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func LGEnsembleFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensem
289289
} else if nTreePerIteration < 1 {
290290
return nil, fmt.Errorf("num_tree_per_iteration (%d) should be > 0", nTreePerIteration)
291291
}
292-
e.nClasses = nClasses
292+
e.nRawOutputGroups = nClasses
293293

294294
maxFeatureIdx, err := params.ToInt("max_feature_idx")
295295
if err != nil {
@@ -312,8 +312,8 @@ func LGEnsembleFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensem
312312
nTrees := len(treeSizes)
313313
if nTrees == 0 {
314314
return nil, fmt.Errorf("no trees in file (based on tree_sizes value)")
315-
} else if nTrees%e.nClasses != 0 {
316-
return nil, fmt.Errorf("wrong number of trees (%d) for number of class (%d)", nTrees, e.nClasses)
315+
} else if nTrees%e.nRawOutputGroups != 0 {
316+
return nil, fmt.Errorf("wrong number of trees (%d) for number of class (%d)", nTrees, e.nRawOutputGroups)
317317
}
318318

319319
e.Trees = make([]lgTree, 0, nTrees)
@@ -584,14 +584,14 @@ func LGEnsembleFromJSON(reader io.Reader, loadTransformation bool) (*Ensemble, e
584584
} else if data.NumTreesPerIteration < 1 {
585585
return nil, fmt.Errorf("num_tree_per_iteration (%d) should be > 0", data.NumTreesPerIteration)
586586
}
587-
e.nClasses = data.NumClasses
587+
e.nRawOutputGroups = data.NumClasses
588588
e.MaxFeatureIdx = data.MaxFeatureIdx
589589

590590
nTrees := len(data.Trees)
591591
if nTrees == 0 {
592592
return nil, fmt.Errorf("no trees in file (based on tree_sizes value)")
593-
} else if nTrees%e.nClasses != 0 {
594-
return nil, fmt.Errorf("wrong number of trees (%d) for number of class (%d)", nTrees, e.nClasses)
593+
} else if nTrees%e.nRawOutputGroups != 0 {
594+
return nil, fmt.Errorf("wrong number of trees (%d) for number of class (%d)", nTrees, e.nRawOutputGroups)
595595
}
596596

597597
e.Trees = make([]lgTree, 0, nTrees)

lgensemble_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ func TestLGEnsembleJSON1tree1leaf(t *testing.T) {
242242
t.Fatalf("expected 1 trees (got %d)", model.NEstimators())
243243
}
244244

245-
if model.NClasses() != 1 {
246-
t.Fatalf("expected 1 class (got %d)", model.NClasses())
245+
if model.NRawOutputGroups() != 1 {
246+
t.Fatalf("expected 1 class (got %d)", model.NRawOutputGroups())
247247
}
248248

249249
if model.NFeatures() != 41 {
@@ -273,8 +273,8 @@ func TestLGEnsembleJSON1tree(t *testing.T) {
273273
t.Fatalf("expected 1 trees (got %d)", model.NEstimators())
274274
}
275275

276-
if model.NClasses() != 1 {
277-
t.Fatalf("expected 1 class (got %d)", model.NClasses())
276+
if model.NRawOutputGroups() != 1 {
277+
t.Fatalf("expected 1 class (got %d)", model.NRawOutputGroups())
278278
}
279279

280280
if model.NFeatures() != 2 {

0 commit comments

Comments
 (0)