@@ -2,13 +2,52 @@ package leaves
2
2
3
3
import (
4
4
"bufio"
5
+ "encoding/json"
5
6
"fmt"
7
+ "io"
6
8
"os"
9
+ "strconv"
7
10
"strings"
8
11
9
12
"github.com/dmitryikh/leaves/util"
10
13
)
11
14
15
// lgEnsembleJSON is the decoding target for the top-level object of a
// LightGBM model stored as JSON. The entries of `tree_info` are kept as raw
// JSON so that every tree can be decoded separately later on.
type lgEnsembleJSON struct {
	// Name and Version identify the format of the model.
	Name    string `json:"name"`
	Version string `json:"version"`

	// NumClasses and NumTreesPerIteration are taken from `num_class` and
	// `num_tree_per_iteration` respectively.
	NumClasses           int `json:"num_class"`
	NumTreesPerIteration int `json:"num_tree_per_iteration"`

	// MaxFeatureIdx is taken from `max_feature_idx`.
	MaxFeatureIdx int `json:"max_feature_idx"`

	// Trees holds the still-encoded trees, in the order they appear in the file.
	Trees []json.RawMessage `json:"tree_info"`

	// TODO: lightgbm should support the next fields
	// AverageOutput bool `json:"average_output"`
	// Objective string `json:"objective"`
}
26
+
27
// lgTreeJSON is the decoding target for one entry of the `tree_info` array.
// Only the fields needed to rebuild the tree are decoded; the tree structure
// itself is kept as raw JSON and decoded in a second step.
type lgTreeJSON struct {
	NumLeaves int    `json:"num_leaves"`
	NumCat    uint32 `json:"num_cat"`
	// Unused fields:
	// TreeIndex uint32 `json:"tree_index"`
	// Shrinkage float64 `json:"shrinkage"`

	// RootRaw is the still-encoded `tree_structure` object.
	RootRaw json.RawMessage `json:"tree_structure"`

	// Root is filled in by the loader after RootRaw has been decoded. It is
	// excluded from JSON decoding: without the "-" tag encoding/json would
	// match a stray "root" key (names are matched case-insensitively) and put
	// arbitrary data here.
	Root interface{} `json:"-"`
}
36
+
37
// lgNodeJSON is the decoding target for a node with a decision rule inside
// `tree_structure`. Children are first captured as raw JSON and decoded in a
// second step, because a child may be either another decision node or a leaf.
type lgNodeJSON struct {
	SplitIndex   uint32 `json:"split_index"`
	SplitFeature uint32 `json:"split_feature"`
	// Threshold could be float64 (for numerical decision) or string (for categorical, example "10||100||400")
	Threshold    interface{} `json:"threshold"`
	DecisionType string      `json:"decision_type"`
	DefaultLeft  bool        `json:"default_left"`
	MissingType  string      `json:"missing_type"`

	// Still-encoded children as they appear in the file.
	LeftChildRaw  json.RawMessage `json:"left_child"`
	RightChildRaw json.RawMessage `json:"right_child"`

	// Decoded children, filled in by the loader. They are excluded from JSON
	// decoding: without the "-" tag encoding/json would match stray
	// "leftchild"/"rightchild" keys (names are matched case-insensitively)
	// and put arbitrary data here.
	LeftChild  interface{} `json:"-"`
	RightChild interface{} `json:"-"`
}
50
+
12
51
func convertMissingType (decisionType uint32 ) (uint8 , error ) {
13
52
missingTypeOrig := (decisionType >> 2 ) & 3
14
53
missingType := uint8 (0 )
@@ -24,6 +63,12 @@ func convertMissingType(decisionType uint32) (uint8, error) {
24
63
return missingType , nil
25
64
}
26
65
66
+ var stringToMissingType = map [string ]uint8 {
67
+ "None" : 0 ,
68
+ "Zero" : missingZero ,
69
+ "NaN" : missingNan ,
70
+ }
71
+
27
72
func lgTreeFromReader (reader * bufio.Reader ) (lgTree , error ) {
28
73
t := lgTree {}
29
74
params , err := util .ReadParamsUntilBlank (reader )
@@ -258,6 +303,7 @@ func LGEnsembleFromReader(reader *bufio.Reader) (*Ensemble, error) {
258
303
}
259
304
treeSizes := strings .Split (treeSizesStr , " " )
260
305
306
+ // NOTE: we rely on the fact that size of tree_sizes data is equal to number of trees
261
307
nTrees := len (treeSizes )
262
308
if nTrees == 0 {
263
309
return nil , fmt .Errorf ("no trees in file (based on tree_sizes value)" )
@@ -286,3 +332,265 @@ func LGEnsembleFromFile(filename string) (*Ensemble, error) {
286
332
bufReader := bufio .NewReader (reader )
287
333
return LGEnsembleFromReader (bufReader )
288
334
}
335
+
336
+ // unmarshalNode recuirsively unmarshal nodes data in the tree from JSON raw data. Tree's node can be:
337
+ // 1. leaf node (contains field 'field_value')
338
+ // 2. node with decision rule (contains field from `lgNodeJSON` structure)
339
+ func unmarshalNode (raw []byte ) (interface {}, error ) {
340
+ node := & lgNodeJSON {}
341
+ err := json .Unmarshal (raw , node )
342
+ if err != nil {
343
+ return nil , err
344
+ }
345
+
346
+ // dirty way to check that we really load a lgNodeJSON struct from raw data
347
+ if node .MissingType == "" {
348
+ // this is no tree node structure, then it should be map with "leaf_value" record
349
+ data := make (map [string ]interface {})
350
+ err = json .Unmarshal (raw , & data )
351
+ if err != nil {
352
+ return nil , err
353
+ }
354
+ value , ok := data ["leaf_value" ].(float64 )
355
+ if ! ok {
356
+ return nil , fmt .Errorf ("unknown tree" )
357
+ }
358
+ return value , nil
359
+ }
360
+ node .LeftChild , err = unmarshalNode (node .LeftChildRaw )
361
+ if err != nil {
362
+ return nil , err
363
+ }
364
+ node .RightChild , err = unmarshalNode (node .RightChildRaw )
365
+ if err != nil {
366
+ return nil , err
367
+ }
368
+ return node , nil
369
+ }
370
+
371
+ // unmarshalTree unmarshal tree data from JSON raw data and convert it to `lgTree` structure
372
+ func unmarshalTree (raw []byte ) (lgTree , error ) {
373
+ t := lgTree {}
374
+
375
+ treeJSON := & lgTreeJSON {}
376
+ err := json .Unmarshal (raw , treeJSON )
377
+ if err != nil {
378
+ return t , err
379
+ }
380
+
381
+ t .nCategorical = treeJSON .NumCat
382
+ if t .nCategorical > 0 {
383
+ // first element set to zero for consistency
384
+ t .catBoundaries = make ([]uint32 , 1 )
385
+ }
386
+
387
+ if treeJSON .NumLeaves < 1 {
388
+ return t , fmt .Errorf ("num_leaves < 1" )
389
+ }
390
+ numNodes := treeJSON .NumLeaves - 1
391
+
392
+ treeJSON .Root , err = unmarshalNode (treeJSON .RootRaw )
393
+ if err != nil {
394
+ return t , err
395
+ }
396
+
397
+ if value , ok := treeJSON .Root .(float64 ); ok {
398
+ // special case - constant value tree
399
+ t .leafValues = append (t .leafValues , value )
400
+ return t , nil
401
+ }
402
+
403
+ createNumericalNode := func (nodeJSON * lgNodeJSON ) (lgNode , error ) {
404
+ node := lgNode {}
405
+ missingType , isFound := stringToMissingType [nodeJSON .MissingType ]
406
+ if ! isFound {
407
+ return node , fmt .Errorf ("unknown missing_type '%s'" , nodeJSON .MissingType )
408
+ }
409
+ defaultType := uint8 (0 )
410
+ if nodeJSON .DefaultLeft {
411
+ defaultType = defaultLeft
412
+ }
413
+ threshold , ok := nodeJSON .Threshold .(float64 )
414
+ if ! ok {
415
+ return node , fmt .Errorf ("unexpected Threshold type %T" , nodeJSON .Threshold )
416
+ }
417
+ node = numericalNode (nodeJSON .SplitFeature , missingType , threshold , defaultType )
418
+ if value , ok := nodeJSON .LeftChild .(float64 ); ok {
419
+ node .Flags |= leftLeaf
420
+ node .Left = uint32 (len (t .leafValues ))
421
+ t .leafValues = append (t .leafValues , value )
422
+ }
423
+ if value , ok := nodeJSON .RightChild .(float64 ); ok {
424
+ node .Flags |= rightLeaf
425
+ node .Right = uint32 (len (t .leafValues ))
426
+ t .leafValues = append (t .leafValues , value )
427
+ }
428
+ return node , nil
429
+ }
430
+
431
+ createCategoricalNode := func (nodeJSON * lgNodeJSON ) (lgNode , error ) {
432
+ node := lgNode {}
433
+ missingType , isFound := stringToMissingType [nodeJSON .MissingType ]
434
+ if ! isFound {
435
+ return node , fmt .Errorf ("unknown missing_type '%s'" , nodeJSON .MissingType )
436
+ }
437
+
438
+ thresholdString , ok := nodeJSON .Threshold .(string )
439
+ if ! ok {
440
+ return node , fmt .Errorf ("unexpected Threshold type %T" , nodeJSON .Threshold )
441
+ }
442
+ tokens := strings .Split (thresholdString , "||" )
443
+
444
+ nBits := len (tokens )
445
+ catIdx := uint32 (0 )
446
+ catType := uint8 (0 )
447
+ if nBits == 0 {
448
+ return node , fmt .Errorf ("no bits set" )
449
+ } else if nBits == 1 {
450
+ value , err := strconv .Atoi (tokens [0 ])
451
+ if err != nil {
452
+ return node , fmt .Errorf ("can't convert %s: %s" , tokens [0 ], err .Error ())
453
+ }
454
+ catIdx = uint32 (value )
455
+ catType = catOneHot
456
+ } else {
457
+ thresholdValues := make ([]int , len (tokens ))
458
+ for i , valueStr := range tokens {
459
+ value , err := strconv .Atoi (valueStr )
460
+ if err != nil {
461
+ return node , fmt .Errorf ("can't convert %s: %s" , valueStr , err .Error ())
462
+ }
463
+ thresholdValues [i ] = value
464
+ }
465
+
466
+ bitset := util .ConstructBitset (thresholdValues )
467
+ if len (bitset ) == 1 {
468
+ catIdx = bitset [0 ]
469
+ catType = catSmall
470
+ } else {
471
+ // regular case with large bitset
472
+ catIdx = uint32 (len (t .catBoundaries ) - 1 )
473
+ t .catThresholds = append (t .catThresholds , bitset ... )
474
+ t .catBoundaries = append (t .catBoundaries , uint32 (len (t .catThresholds )))
475
+ }
476
+ }
477
+
478
+ node = categoricalNode (nodeJSON .SplitFeature , missingType , catIdx , catType )
479
+ if value , ok := nodeJSON .LeftChild .(float64 ); ok {
480
+ node .Flags |= leftLeaf
481
+ node .Left = uint32 (len (t .leafValues ))
482
+ t .leafValues = append (t .leafValues , value )
483
+ }
484
+ if value , ok := nodeJSON .RightChild .(float64 ); ok {
485
+ node .Flags |= rightLeaf
486
+ node .Right = uint32 (len (t .leafValues ))
487
+ t .leafValues = append (t .leafValues , value )
488
+ }
489
+ return node , nil
490
+ }
491
+ createNode := func (nodeJSON * lgNodeJSON ) (lgNode , error ) {
492
+ if nodeJSON .DecisionType == "==" {
493
+ return createCategoricalNode (nodeJSON )
494
+ } else if nodeJSON .DecisionType == "<=" {
495
+ return createNumericalNode (nodeJSON )
496
+ } else {
497
+ return lgNode {}, fmt .Errorf ("unknown decision type '%s'" , nodeJSON .DecisionType )
498
+ }
499
+ }
500
+
501
+ type StackData struct {
502
+ // pointer to parent's Left/RightChild field
503
+ parentPtr * uint32
504
+ nodeJSON * lgNodeJSON
505
+ }
506
+ stack := make ([]StackData , 0 , numNodes )
507
+ if root , ok := treeJSON .Root .(* lgNodeJSON ); ok {
508
+ stack = append (stack , StackData {nil , root })
509
+ } else {
510
+ return t , fmt .Errorf ("unexpected type of Root: %T" , treeJSON .Root )
511
+ }
512
+ // NOTE: we rely on fact that t.nodes won't be reallocated (`parentPtr` points to its data)
513
+ t .nodes = make ([]lgNode , 0 , numNodes )
514
+
515
+ for len (stack ) > 0 {
516
+ stackData := stack [len (stack )- 1 ]
517
+ stack = stack [:len (stack )- 1 ]
518
+ node , err := createNode (stackData .nodeJSON )
519
+ if err != nil {
520
+ return t , err
521
+ }
522
+ if stackData .parentPtr != nil {
523
+ * stackData .parentPtr = uint32 (len (t .nodes ))
524
+ }
525
+ t .nodes = append (t .nodes , node )
526
+ if node .Flags & leftLeaf == 0 {
527
+ if left , ok := stackData .nodeJSON .LeftChild .(* lgNodeJSON ); ok {
528
+ stack = append (stack , StackData {& t .nodes [len (t .nodes )- 1 ].Left , left })
529
+ } else if _ , ok := stackData .nodeJSON .LeftChild .(float64 ); ok {
530
+ } else {
531
+ return t , fmt .Errorf ("unexpected left child type %T" , stackData .nodeJSON .LeftChild )
532
+ }
533
+ }
534
+ if node .Flags & rightLeaf == 0 {
535
+ if right , ok := stackData .nodeJSON .RightChild .(* lgNodeJSON ); ok {
536
+ stack = append (stack , StackData {& t .nodes [len (t .nodes )- 1 ].Right , right })
537
+ } else if _ , ok := stackData .nodeJSON .RightChild .(float64 ); ok {
538
+ } else {
539
+ return t , fmt .Errorf ("unexpected right child type %T" , stackData .nodeJSON .RightChild )
540
+ }
541
+ }
542
+ }
543
+ return t , nil
544
+ }
545
+
546
+ // LGEnsembleFromJSON reads LightGBM model from stream with JSON data
547
+ func LGEnsembleFromJSON (reader io.Reader ) (* Ensemble , error ) {
548
+ data := & lgEnsembleJSON {}
549
+ dec := json .NewDecoder (reader )
550
+
551
+ err := dec .Decode (data )
552
+ if err != nil {
553
+ return nil , err
554
+ }
555
+
556
+ e := & lgEnsemble {name : "lightgbm.gbdt" }
557
+
558
+ if data .Name != "tree" {
559
+ return nil , fmt .Errorf ("expected 'name' field = 'tree' (got: '%s')" , data .Name )
560
+ }
561
+
562
+ if data .Version != "v2" {
563
+ return nil , fmt .Errorf ("expected 'version' field = 'v2' (got: '%s')" , data .Version )
564
+ }
565
+
566
+ if data .NumClasses != data .NumTreesPerIteration {
567
+ return nil , fmt .Errorf (
568
+ "meet case when num_class (%d) != num_tree_per_iteration (%d)" ,
569
+ data .NumClasses ,
570
+ data .NumTreesPerIteration ,
571
+ )
572
+ } else if data .NumClasses < 1 {
573
+ return nil , fmt .Errorf ("num_class (%d) should be > 0" , data .NumClasses )
574
+ } else if data .NumTreesPerIteration < 1 {
575
+ return nil , fmt .Errorf ("num_tree_per_iteration (%d) should be > 0" , data .NumTreesPerIteration )
576
+ }
577
+ e .nClasses = data .NumClasses
578
+ e .MaxFeatureIdx = data .MaxFeatureIdx
579
+
580
+ nTrees := len (data .Trees )
581
+ if nTrees == 0 {
582
+ return nil , fmt .Errorf ("no trees in file (based on tree_sizes value)" )
583
+ } else if nTrees % e .nClasses != 0 {
584
+ return nil , fmt .Errorf ("wrong number of trees (%d) for number of class (%d)" , nTrees , e .nClasses )
585
+ }
586
+
587
+ e .Trees = make ([]lgTree , 0 , nTrees )
588
+ for i := 0 ; i < nTrees ; i ++ {
589
+ tree , err := unmarshalTree (data .Trees [i ])
590
+ if err != nil {
591
+ return nil , fmt .Errorf ("error while reading %d tree: %s" , i , err .Error ())
592
+ }
593
+ e .Trees = append (e .Trees , tree )
594
+ }
595
+ return & Ensemble {e }, nil
596
+ }
0 commit comments