Skip to content

Commit 0506b17

Browse files
committed
[+] lightgbm: support reading model from JSON format
1 parent 0bea11e commit 0506b17

File tree

3 files changed

+358
-0
lines changed

3 files changed

+358
-0
lines changed

lgensemble_io.go

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,52 @@ package leaves
22

33
import (
44
"bufio"
5+
"encoding/json"
56
"fmt"
7+
"io"
68
"os"
9+
"strconv"
710
"strings"
811

912
"github.com/dmitryikh/leaves/util"
1013
)
1114

15+
// lgEnsembleJSON is the decoding target for the top-level object of a
// LightGBM model stored as JSON. Trees are kept as raw JSON messages so that
// each of them can be decoded on its own later.
type lgEnsembleJSON struct {
	Name                 string            `json:"name"`
	Version              string            `json:"version"`
	NumClasses           int               `json:"num_class"`
	NumTreesPerIteration int               `json:"num_tree_per_iteration"`
	MaxFeatureIdx        int               `json:"max_feature_idx"`
	Trees                []json.RawMessage `json:"tree_info"`

	// TODO: the following model fields are not decoded yet and should be supported:
	// AverageOutput bool `json:"average_output"`
	// Objective string `json:"objective"`
}
26+
27+
// lgTreeJSON is the decoding target for one element of the "tree_info" array.
// The tree body ("tree_structure") is kept as raw JSON in RootRaw and decoded
// separately into Root.
type lgTreeJSON struct {
	NumLeaves int    `json:"num_leaves"`
	NumCat    uint32 `json:"num_cat"`
	// Unused fields:
	// TreeIndex uint32 `json:"tree_index"`
	// Shrinkage float64 `json:"shrinkage"`
	RootRaw json.RawMessage `json:"tree_structure"`
	// Root holds the decoded form of RootRaw. It is filled in by code, never
	// by the JSON decoder, hence the "-" tag: without it encoding/json would
	// populate the field from a stray "Root" key in the input.
	Root interface{} `json:"-"`
}
36+
37+
// lgNodeJSON is the decoding target for a decision (non-leaf) node of a tree.
// Children are kept as raw JSON (LeftChildRaw/RightChildRaw) and decoded
// separately into LeftChild/RightChild.
type lgNodeJSON struct {
	SplitIndex   uint32 `json:"split_index"`
	SplitFeature uint32 `json:"split_feature"`
	// Threshold could be float64 (for numerical decision) or string (for categorical, example "10||100||400")
	Threshold     interface{}     `json:"threshold"`
	DecisionType  string          `json:"decision_type"`
	DefaultLeft   bool            `json:"default_left"`
	MissingType   string          `json:"missing_type"`
	LeftChildRaw  json.RawMessage `json:"left_child"`
	RightChildRaw json.RawMessage `json:"right_child"`
	// LeftChild and RightChild hold the decoded children. They are filled in
	// by code, never by the JSON decoder, hence the "-" tags: without them
	// encoding/json would populate the fields from stray "LeftChild" /
	// "RightChild" keys in the input.
	LeftChild  interface{} `json:"-"`
	RightChild interface{} `json:"-"`
}
50+
1251
func convertMissingType(decisionType uint32) (uint8, error) {
1352
missingTypeOrig := (decisionType >> 2) & 3
1453
missingType := uint8(0)
@@ -24,6 +63,12 @@ func convertMissingType(decisionType uint32) (uint8, error) {
2463
return missingType, nil
2564
}
2665

66+
var stringToMissingType = map[string]uint8{
67+
"None": 0,
68+
"Zero": missingZero,
69+
"NaN": missingNan,
70+
}
71+
2772
func lgTreeFromReader(reader *bufio.Reader) (lgTree, error) {
2873
t := lgTree{}
2974
params, err := util.ReadParamsUntilBlank(reader)
@@ -258,6 +303,7 @@ func LGEnsembleFromReader(reader *bufio.Reader) (*Ensemble, error) {
258303
}
259304
treeSizes := strings.Split(treeSizesStr, " ")
260305

306+
// NOTE: we rely on the fact that size of tree_sizes data is equal to number of trees
261307
nTrees := len(treeSizes)
262308
if nTrees == 0 {
263309
return nil, fmt.Errorf("no trees in file (based on tree_sizes value)")
@@ -286,3 +332,265 @@ func LGEnsembleFromFile(filename string) (*Ensemble, error) {
286332
bufReader := bufio.NewReader(reader)
287333
return LGEnsembleFromReader(bufReader)
288334
}
335+
336+
// unmarshalNode recuirsively unmarshal nodes data in the tree from JSON raw data. Tree's node can be:
337+
// 1. leaf node (contains field 'field_value')
338+
// 2. node with decision rule (contains field from `lgNodeJSON` structure)
339+
func unmarshalNode(raw []byte) (interface{}, error) {
340+
node := &lgNodeJSON{}
341+
err := json.Unmarshal(raw, node)
342+
if err != nil {
343+
return nil, err
344+
}
345+
346+
// dirty way to check that we really load a lgNodeJSON struct from raw data
347+
if node.MissingType == "" {
348+
// this is no tree node structure, then it should be map with "leaf_value" record
349+
data := make(map[string]interface{})
350+
err = json.Unmarshal(raw, &data)
351+
if err != nil {
352+
return nil, err
353+
}
354+
value, ok := data["leaf_value"].(float64)
355+
if !ok {
356+
return nil, fmt.Errorf("unknown tree")
357+
}
358+
return value, nil
359+
}
360+
node.LeftChild, err = unmarshalNode(node.LeftChildRaw)
361+
if err != nil {
362+
return nil, err
363+
}
364+
node.RightChild, err = unmarshalNode(node.RightChildRaw)
365+
if err != nil {
366+
return nil, err
367+
}
368+
return node, nil
369+
}
370+
371+
// unmarshalTree unmarshal tree data from JSON raw data and convert it to `lgTree` structure
372+
func unmarshalTree(raw []byte) (lgTree, error) {
373+
t := lgTree{}
374+
375+
treeJSON := &lgTreeJSON{}
376+
err := json.Unmarshal(raw, treeJSON)
377+
if err != nil {
378+
return t, err
379+
}
380+
381+
t.nCategorical = treeJSON.NumCat
382+
if t.nCategorical > 0 {
383+
// first element set to zero for consistency
384+
t.catBoundaries = make([]uint32, 1)
385+
}
386+
387+
if treeJSON.NumLeaves < 1 {
388+
return t, fmt.Errorf("num_leaves < 1")
389+
}
390+
numNodes := treeJSON.NumLeaves - 1
391+
392+
treeJSON.Root, err = unmarshalNode(treeJSON.RootRaw)
393+
if err != nil {
394+
return t, err
395+
}
396+
397+
if value, ok := treeJSON.Root.(float64); ok {
398+
// special case - constant value tree
399+
t.leafValues = append(t.leafValues, value)
400+
return t, nil
401+
}
402+
403+
createNumericalNode := func(nodeJSON *lgNodeJSON) (lgNode, error) {
404+
node := lgNode{}
405+
missingType, isFound := stringToMissingType[nodeJSON.MissingType]
406+
if !isFound {
407+
return node, fmt.Errorf("unknown missing_type '%s'", nodeJSON.MissingType)
408+
}
409+
defaultType := uint8(0)
410+
if nodeJSON.DefaultLeft {
411+
defaultType = defaultLeft
412+
}
413+
threshold, ok := nodeJSON.Threshold.(float64)
414+
if !ok {
415+
return node, fmt.Errorf("unexpected Threshold type %T", nodeJSON.Threshold)
416+
}
417+
node = numericalNode(nodeJSON.SplitFeature, missingType, threshold, defaultType)
418+
if value, ok := nodeJSON.LeftChild.(float64); ok {
419+
node.Flags |= leftLeaf
420+
node.Left = uint32(len(t.leafValues))
421+
t.leafValues = append(t.leafValues, value)
422+
}
423+
if value, ok := nodeJSON.RightChild.(float64); ok {
424+
node.Flags |= rightLeaf
425+
node.Right = uint32(len(t.leafValues))
426+
t.leafValues = append(t.leafValues, value)
427+
}
428+
return node, nil
429+
}
430+
431+
createCategoricalNode := func(nodeJSON *lgNodeJSON) (lgNode, error) {
432+
node := lgNode{}
433+
missingType, isFound := stringToMissingType[nodeJSON.MissingType]
434+
if !isFound {
435+
return node, fmt.Errorf("unknown missing_type '%s'", nodeJSON.MissingType)
436+
}
437+
438+
thresholdString, ok := nodeJSON.Threshold.(string)
439+
if !ok {
440+
return node, fmt.Errorf("unexpected Threshold type %T", nodeJSON.Threshold)
441+
}
442+
tokens := strings.Split(thresholdString, "||")
443+
444+
nBits := len(tokens)
445+
catIdx := uint32(0)
446+
catType := uint8(0)
447+
if nBits == 0 {
448+
return node, fmt.Errorf("no bits set")
449+
} else if nBits == 1 {
450+
value, err := strconv.Atoi(tokens[0])
451+
if err != nil {
452+
return node, fmt.Errorf("can't convert %s: %s", tokens[0], err.Error())
453+
}
454+
catIdx = uint32(value)
455+
catType = catOneHot
456+
} else {
457+
thresholdValues := make([]int, len(tokens))
458+
for i, valueStr := range tokens {
459+
value, err := strconv.Atoi(valueStr)
460+
if err != nil {
461+
return node, fmt.Errorf("can't convert %s: %s", valueStr, err.Error())
462+
}
463+
thresholdValues[i] = value
464+
}
465+
466+
bitset := util.ConstructBitset(thresholdValues)
467+
if len(bitset) == 1 {
468+
catIdx = bitset[0]
469+
catType = catSmall
470+
} else {
471+
// regular case with large bitset
472+
catIdx = uint32(len(t.catBoundaries) - 1)
473+
t.catThresholds = append(t.catThresholds, bitset...)
474+
t.catBoundaries = append(t.catBoundaries, uint32(len(t.catThresholds)))
475+
}
476+
}
477+
478+
node = categoricalNode(nodeJSON.SplitFeature, missingType, catIdx, catType)
479+
if value, ok := nodeJSON.LeftChild.(float64); ok {
480+
node.Flags |= leftLeaf
481+
node.Left = uint32(len(t.leafValues))
482+
t.leafValues = append(t.leafValues, value)
483+
}
484+
if value, ok := nodeJSON.RightChild.(float64); ok {
485+
node.Flags |= rightLeaf
486+
node.Right = uint32(len(t.leafValues))
487+
t.leafValues = append(t.leafValues, value)
488+
}
489+
return node, nil
490+
}
491+
createNode := func(nodeJSON *lgNodeJSON) (lgNode, error) {
492+
if nodeJSON.DecisionType == "==" {
493+
return createCategoricalNode(nodeJSON)
494+
} else if nodeJSON.DecisionType == "<=" {
495+
return createNumericalNode(nodeJSON)
496+
} else {
497+
return lgNode{}, fmt.Errorf("unknown decision type '%s'", nodeJSON.DecisionType)
498+
}
499+
}
500+
501+
type StackData struct {
502+
// pointer to parent's Left/RightChild field
503+
parentPtr *uint32
504+
nodeJSON *lgNodeJSON
505+
}
506+
stack := make([]StackData, 0, numNodes)
507+
if root, ok := treeJSON.Root.(*lgNodeJSON); ok {
508+
stack = append(stack, StackData{nil, root})
509+
} else {
510+
return t, fmt.Errorf("unexpected type of Root: %T", treeJSON.Root)
511+
}
512+
// NOTE: we rely on fact that t.nodes won't be reallocated (`parentPtr` points to its data)
513+
t.nodes = make([]lgNode, 0, numNodes)
514+
515+
for len(stack) > 0 {
516+
stackData := stack[len(stack)-1]
517+
stack = stack[:len(stack)-1]
518+
node, err := createNode(stackData.nodeJSON)
519+
if err != nil {
520+
return t, err
521+
}
522+
if stackData.parentPtr != nil {
523+
*stackData.parentPtr = uint32(len(t.nodes))
524+
}
525+
t.nodes = append(t.nodes, node)
526+
if node.Flags&leftLeaf == 0 {
527+
if left, ok := stackData.nodeJSON.LeftChild.(*lgNodeJSON); ok {
528+
stack = append(stack, StackData{&t.nodes[len(t.nodes)-1].Left, left})
529+
} else if _, ok := stackData.nodeJSON.LeftChild.(float64); ok {
530+
} else {
531+
return t, fmt.Errorf("unexpected left child type %T", stackData.nodeJSON.LeftChild)
532+
}
533+
}
534+
if node.Flags&rightLeaf == 0 {
535+
if right, ok := stackData.nodeJSON.RightChild.(*lgNodeJSON); ok {
536+
stack = append(stack, StackData{&t.nodes[len(t.nodes)-1].Right, right})
537+
} else if _, ok := stackData.nodeJSON.RightChild.(float64); ok {
538+
} else {
539+
return t, fmt.Errorf("unexpected right child type %T", stackData.nodeJSON.RightChild)
540+
}
541+
}
542+
}
543+
return t, nil
544+
}
545+
546+
// LGEnsembleFromJSON reads LightGBM model from stream with JSON data
547+
func LGEnsembleFromJSON(reader io.Reader) (*Ensemble, error) {
548+
data := &lgEnsembleJSON{}
549+
dec := json.NewDecoder(reader)
550+
551+
err := dec.Decode(data)
552+
if err != nil {
553+
return nil, err
554+
}
555+
556+
e := &lgEnsemble{name: "lightgbm.gbdt"}
557+
558+
if data.Name != "tree" {
559+
return nil, fmt.Errorf("expected 'name' field = 'tree' (got: '%s')", data.Name)
560+
}
561+
562+
if data.Version != "v2" {
563+
return nil, fmt.Errorf("expected 'version' field = 'v2' (got: '%s')", data.Version)
564+
}
565+
566+
if data.NumClasses != data.NumTreesPerIteration {
567+
return nil, fmt.Errorf(
568+
"meet case when num_class (%d) != num_tree_per_iteration (%d)",
569+
data.NumClasses,
570+
data.NumTreesPerIteration,
571+
)
572+
} else if data.NumClasses < 1 {
573+
return nil, fmt.Errorf("num_class (%d) should be > 0", data.NumClasses)
574+
} else if data.NumTreesPerIteration < 1 {
575+
return nil, fmt.Errorf("num_tree_per_iteration (%d) should be > 0", data.NumTreesPerIteration)
576+
}
577+
e.nClasses = data.NumClasses
578+
e.MaxFeatureIdx = data.MaxFeatureIdx
579+
580+
nTrees := len(data.Trees)
581+
if nTrees == 0 {
582+
return nil, fmt.Errorf("no trees in file (based on tree_sizes value)")
583+
} else if nTrees%e.nClasses != 0 {
584+
return nil, fmt.Errorf("wrong number of trees (%d) for number of class (%d)", nTrees, e.nClasses)
585+
}
586+
587+
e.Trees = make([]lgTree, 0, nTrees)
588+
for i := 0; i < nTrees; i++ {
589+
tree, err := unmarshalTree(data.Trees[i])
590+
if err != nil {
591+
return nil, fmt.Errorf("error while reading %d tree: %s", i, err.Error())
592+
}
593+
e.Trees = append(e.Trees, tree)
594+
}
595+
return &Ensemble{e}, nil
596+
}

util/util.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,33 @@ func NumberOfSetBits(bitset []uint32) uint32 {
182182
return count
183183
}
184184

185+
// ConstructBitset returns a slice of 32-bit words in which bit number v is set
// for every v in `values`: bit v lives in word v/32 at position v%32. The
// slice is just long enough to hold the largest value.
//
// Negative values can't be represented and are ignored. If `values` is empty
// or holds only negative values, nil is returned.
func ConstructBitset(values []int) []uint32 {
	maxValue := -1
	for _, v := range values {
		if v > maxValue {
			maxValue = v
		}
	}
	if maxValue < 0 {
		return nil
	}

	bitset := make([]uint32, maxValue/32+1)
	for _, v := range values {
		if v < 0 {
			// would index before the start of the slice or shift by a
			// negative amount
			continue
		}
		bitset[v/32] |= uint32(1) << uint(v%32)
	}
	return bitset
}
211+
185212
// AlmostEqualFloat64 reports whether the absolute difference between a and b
// does not exceed threshold.
func AlmostEqualFloat64(a, b, threshold float64) bool {
	diff := math.Abs(a - b)
	return diff <= threshold
}

0 commit comments

Comments
 (0)