@@ -739,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
                    context->input("bias_feature_id", &bias_feature_id_t));
     int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
 
+    const Tensor* weak_learner_type_t;
+    OP_REQUIRES_OK(context,
+                   context->input("weak_learner_type", &weak_learner_type_t));
+    const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
     // Find the number of unique partitions before we allocate the output.
     std::vector<int32> partition_boundaries;
     std::vector<int32> non_empty_partitions;
@@ -767,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
     tensorflow::TTypes<int32>::Vec output_partition_ids =
         output_partition_ids_t->vec<int32>();
 
+    // For a normal tree, we output a split per partition. For an oblivious
+    // tree, we output one split for all partitions of the layer.
+    int size_output = num_elements;
+    if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+        num_elements > 0) {
+      size_output = 1;
+    }
+
     Tensor* gains_t = nullptr;
-    OP_REQUIRES_OK(
-        context, context->allocate_output("gains", TensorShape({num_elements}),
-                                          &gains_t));
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                "gains", TensorShape({size_output}), &gains_t));
 
     tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
 
     Tensor* output_splits_t = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                "split_infos", TensorShape({num_elements}),
-                                &output_splits_t));
+    OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+                                                     TensorShape({size_output}),
+                                                     &output_splits_t));
     tensorflow::TTypes<string>::Vec output_splits =
         output_splits_t->vec<string>();
+    if (num_elements == 0) {
+      return;
+    }
     SplitBuilderState state(context);
+    switch (weak_learner_type) {
+      case LearnerConfig::NORMAL_DECISION_TREE: {
+        ComputeNormalDecisionTree(
+            context, &state, normalizer_ratio, num_elements,
+            partition_boundaries, non_empty_partitions, bias_feature_id,
+            partition_ids, feature_ids, gradients_t, hessians_t,
+            &output_partition_ids, &gains, &output_splits);
+        break;
+      }
+      case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+        ComputeObliviousDecisionTree(
+            context, &state, normalizer_ratio, num_elements,
+            partition_boundaries, non_empty_partitions, bias_feature_id,
+            partition_ids, feature_ids, gradients_t, hessians_t,
+            &output_partition_ids, &gains, &output_splits);
+        break;
+      }
+    }
+  }
+
+ private:
+  void ComputeNormalDecisionTree(
+      OpKernelContext* const context, SplitBuilderState* state,
+      const float normalizer_ratio, const int num_elements,
+      const std::vector<int32>& partition_boundaries,
+      const std::vector<int32>& non_empty_partitions,
+      const int64 bias_feature_id,
+      const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+      const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+      const Tensor* gradients_t, const Tensor* hessians_t,
+      tensorflow::TTypes<int32>::Vec* output_partition_ids,
+      tensorflow::TTypes<float>::Vec* gains,
+      tensorflow::TTypes<string>::Vec* output_splits) {
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
       float best_gain = std::numeric_limits<float>::lowest();
       int start_index = partition_boundaries[non_empty_partitions[root_idx]];
@@ -790,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
                   errors::InvalidArgument("Bias feature ID missing."));
       GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
       root_gradient_stats *= normalizer_ratio;
-      NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+      NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
       int32 best_feature_idx = 0;
       NodeStats best_right_node_stats(0);
       NodeStats best_left_node_stats(0);
@@ -801,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
         left_gradient_stats *= normalizer_ratio;
         GradientStats right_gradient_stats =
             root_gradient_stats - left_gradient_stats;
-        NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
-        NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+        NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
+        NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
         if (left_stats.gain + right_stats.gain > best_gain) {
           best_gain = left_stats.gain + right_stats.gain;
           best_left_node_stats = left_stats;
@@ -813,18 +861,133 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
       SplitInfo split_info;
       auto* equality_split = split_info.mutable_split_node()
                                  ->mutable_categorical_id_binary_split();
-      equality_split->set_feature_column(state.feature_column_group_id());
+      equality_split->set_feature_column(state->feature_column_group_id());
       equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
       auto* left_child = split_info.mutable_left_child();
       auto* right_child = split_info.mutable_right_child();
-      state.FillLeaf(best_left_node_stats, left_child);
-      state.FillLeaf(best_right_node_stats, right_child);
-      split_info.SerializeToString(&output_splits(root_idx));
-      gains(root_idx) =
-          best_gain - root_stats.gain - state.tree_complexity_regularization();
-      output_partition_ids(root_idx) = partition_ids(start_index);
+      state->FillLeaf(best_left_node_stats, left_child);
+      state->FillLeaf(best_right_node_stats, right_child);
+      split_info.SerializeToString(&(*output_splits)(root_idx));
+      (*gains)(root_idx) =
+          best_gain - root_stats.gain - state->tree_complexity_regularization();
+      (*output_partition_ids)(root_idx) = partition_ids(start_index);
     }
   }
+
+  void ComputeObliviousDecisionTree(
+      OpKernelContext* const context, SplitBuilderState* state,
+      const float normalizer_ratio, const int num_elements,
+      const std::vector<int32>& partition_boundaries,
+      const std::vector<int32>& non_empty_partitions,
+      const int64 bias_feature_id,
+      const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+      const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+      const Tensor* gradients_t, const Tensor* hessians_t,
+      tensorflow::TTypes<int32>::Vec* output_partition_ids,
+      tensorflow::TTypes<float>::Vec* gains,
+      tensorflow::TTypes<string>::Vec* output_splits) {
+    // Holds the root stats for each node to be split.
+    std::vector<GradientStats> current_layer_stats;
+    current_layer_stats.reserve(num_elements);
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      const int start_index = partition_boundaries[root_idx];
+      // First feature ID in each partition should be the bias feature.
+      OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
+                  errors::InvalidArgument("Bias feature ID missing."));
+      GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
+      root_gradient_stats *= normalizer_ratio;
+      current_layer_stats.push_back(root_gradient_stats);
+    }
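+    // Track the best split shared by all nodes of the layer, together with
+    // the candidate stats for the feature id currently being evaluated.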
+    float best_gain = std::numeric_limits<float>::lowest();
+    int64 best_feature_id = 0;
+    std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+    std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+    std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+    std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+    int64 current_feature_id = std::numeric_limits<int64>::max();
+    int64 last_feature_id = -1;
+    // Find the lowest feature id; this is going to be the first feature id we
+    // try.
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      const int start_index = partition_boundaries[root_idx];
+      if (feature_ids(start_index + 1, 0) < current_feature_id) {
+        current_feature_id = feature_ids(start_index + 1, 0);
+      }
+    }
+    // Index offsets for each of the partitions, used to access the gradients
+    // of a partition for the feature currently under consideration. Start at
+    // one because index zero holds the bias.
+    std::vector<int> current_layer_offsets(num_elements, 1);
+    // The idea is to try every feature id in increasing order. In each
+    // iteration we calculate the gain of the layer using the current feature
+    // id as the split value, and we also obtain the next feature id to try.
+    while (current_feature_id > last_feature_id) {
+      last_feature_id = current_feature_id;
+      int64 next_feature_id = -1;
+      // Left gradient stats per node.
+      std::vector<GradientStats> left_gradient_stats(num_elements);
+      for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+        int idx =
+            current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+        const int end_index = partition_boundaries[root_idx + 1];
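+        // If this partition has no entry for current_feature_id, its left
+        // stats stay default-constructed (no gradients on the left), so the
+        // candidate split sends all of this node's examples to the right.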
+        if (idx < end_index && feature_ids(idx, 0) == current_feature_id) {
+          GradientStats g(*gradients_t, *hessians_t, idx);
+          g *= normalizer_ratio;
+          left_gradient_stats[root_idx] = g;
+          current_layer_offsets[root_idx]++;
+          idx++;
+        }
+        if (idx < end_index &&
+            (feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) {
+          next_feature_id = feature_ids(idx, 0);
+        }
+      }
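+      // Layer-wide gain of splitting every node on current_feature_id; each
+      // node's right stats are its totals minus the left stats gathered above.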
+      float gain_of_split = 0.0;
+      for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+        GradientStats right_gradient_stats =
+            current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+        NodeStats left_stat =
+            state->ComputeNodeStats(left_gradient_stats[root_idx]);
+        NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+        gain_of_split += left_stat.gain + right_stat.gain;
+        current_left_node_stats[root_idx] = left_stat;
+        current_right_node_stats[root_idx] = right_stat;
+      }
+      if (gain_of_split > best_gain) {
+        best_gain = gain_of_split;
+        best_left_node_stats = current_left_node_stats;
+        best_right_node_stats = current_right_node_stats;
+        best_feature_id = current_feature_id;
+      }
+      current_feature_id = next_feature_id;
+    }
+
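+    // Report the gain relative to the unsplit layer: subtract every node's
+    // root gain and charge the tree complexity penalty once per node.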
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+    }
+    best_gain -= num_elements * state->tree_complexity_regularization();
+
+    ObliviousSplitInfo oblivious_split_info;
+    auto* equality_split =
+        oblivious_split_info.mutable_split_node()
+            ->mutable_oblivious_categorical_id_binary_split();
+    equality_split->set_feature_column(state->feature_column_group_id());
+    equality_split->set_feature_id(best_feature_id);
+    (*gains)(0) = best_gain;
+
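+    // Children are appended in (left, right) pairs, one pair per node of the
+    // layer; children_parent_id records the partition each pair extends.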
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      auto* left_child = oblivious_split_info.add_children();
+      auto* right_child = oblivious_split_info.add_children();
+
+      state->FillLeaf(best_left_node_stats[root_idx], left_child);
+      state->FillLeaf(best_right_node_stats[root_idx], right_child);
+
+      const int start_index = partition_boundaries[root_idx];
+      (*output_partition_ids)(root_idx) = partition_ids(start_index);
+      oblivious_split_info.add_children_parent_id(partition_ids(start_index));
+    }
+    oblivious_split_info.SerializeToString(&(*output_splits)(0));
+  }
 };
 
 REGISTER_KERNEL_BUILDER(