Skip to content

Commit c4099e6

Browse files
Added support for categorical features.
Ops are now interconnected to support oblivious decision trees. PiperOrigin-RevId: 210642692
1 parent 6eabd59 commit c4099e6

File tree

9 files changed

+504
-21
lines changed

9 files changed

+504
-21
lines changed

tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc

Lines changed: 179 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
739739
context->input("bias_feature_id", &bias_feature_id_t));
740740
int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
741741

742+
const Tensor* weak_learner_type_t;
743+
OP_REQUIRES_OK(context,
744+
context->input("weak_learner_type", &weak_learner_type_t));
745+
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
746+
742747
// Find the number of unique partitions before we allocate the output.
743748
std::vector<int32> partition_boundaries;
744749
std::vector<int32> non_empty_partitions;
@@ -767,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
767772
tensorflow::TTypes<int32>::Vec output_partition_ids =
768773
output_partition_ids_t->vec<int32>();
769774

775+
// For a normal tree, we output a split per partition. For an oblivious
776+
// tree, we output one split for all partitions of the layer.
777+
int size_output = num_elements;
778+
if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
779+
num_elements > 0) {
780+
size_output = 1;
781+
}
782+
770783
Tensor* gains_t = nullptr;
771-
OP_REQUIRES_OK(
772-
context, context->allocate_output("gains", TensorShape({num_elements}),
773-
&gains_t));
784+
OP_REQUIRES_OK(context, context->allocate_output(
785+
"gains", TensorShape({size_output}), &gains_t));
774786

775787
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
776788

777789
Tensor* output_splits_t = nullptr;
778-
OP_REQUIRES_OK(context, context->allocate_output(
779-
"split_infos", TensorShape({num_elements}),
780-
&output_splits_t));
790+
OP_REQUIRES_OK(context, context->allocate_output("split_infos",
791+
TensorShape({size_output}),
792+
&output_splits_t));
781793
tensorflow::TTypes<string>::Vec output_splits =
782794
output_splits_t->vec<string>();
795+
if (num_elements == 0) {
796+
return;
797+
}
783798
SplitBuilderState state(context);
799+
switch (weak_learner_type) {
800+
case LearnerConfig::NORMAL_DECISION_TREE: {
801+
ComputeNormalDecisionTree(
802+
context, &state, normalizer_ratio, num_elements,
803+
partition_boundaries, non_empty_partitions, bias_feature_id,
804+
partition_ids, feature_ids, gradients_t, hessians_t,
805+
&output_partition_ids, &gains, &output_splits);
806+
break;
807+
}
808+
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
809+
ComputeObliviousDecisionTree(
810+
context, &state, normalizer_ratio, num_elements,
811+
partition_boundaries, non_empty_partitions, bias_feature_id,
812+
partition_ids, feature_ids, gradients_t, hessians_t,
813+
&output_partition_ids, &gains, &output_splits);
814+
break;
815+
}
816+
}
817+
}
818+
819+
private:
820+
void ComputeNormalDecisionTree(
821+
OpKernelContext* const context, SplitBuilderState* state,
822+
const float normalizer_ratio, const int num_elements,
823+
const std::vector<int32>& partition_boundaries,
824+
const std::vector<int32>& non_empty_partitions,
825+
const int64 bias_feature_id,
826+
const tensorflow::TTypes<int32>::ConstVec& partition_ids,
827+
const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
828+
const Tensor* gradients_t, const Tensor* hessians_t,
829+
tensorflow::TTypes<int32>::Vec* output_partition_ids,
830+
tensorflow::TTypes<float>::Vec* gains,
831+
tensorflow::TTypes<string>::Vec* output_splits) {
784832
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
785833
float best_gain = std::numeric_limits<float>::lowest();
786834
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
@@ -790,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
790838
errors::InvalidArgument("Bias feature ID missing."));
791839
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
792840
root_gradient_stats *= normalizer_ratio;
793-
NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
841+
NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
794842
int32 best_feature_idx = 0;
795843
NodeStats best_right_node_stats(0);
796844
NodeStats best_left_node_stats(0);
@@ -801,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
801849
left_gradient_stats *= normalizer_ratio;
802850
GradientStats right_gradient_stats =
803851
root_gradient_stats - left_gradient_stats;
804-
NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
805-
NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
852+
NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
853+
NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
806854
if (left_stats.gain + right_stats.gain > best_gain) {
807855
best_gain = left_stats.gain + right_stats.gain;
808856
best_left_node_stats = left_stats;
@@ -813,18 +861,133 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
813861
SplitInfo split_info;
814862
auto* equality_split = split_info.mutable_split_node()
815863
->mutable_categorical_id_binary_split();
816-
equality_split->set_feature_column(state.feature_column_group_id());
864+
equality_split->set_feature_column(state->feature_column_group_id());
817865
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
818866
auto* left_child = split_info.mutable_left_child();
819867
auto* right_child = split_info.mutable_right_child();
820-
state.FillLeaf(best_left_node_stats, left_child);
821-
state.FillLeaf(best_right_node_stats, right_child);
822-
split_info.SerializeToString(&output_splits(root_idx));
823-
gains(root_idx) =
824-
best_gain - root_stats.gain - state.tree_complexity_regularization();
825-
output_partition_ids(root_idx) = partition_ids(start_index);
868+
state->FillLeaf(best_left_node_stats, left_child);
869+
state->FillLeaf(best_right_node_stats, right_child);
870+
split_info.SerializeToString(&(*output_splits)(root_idx));
871+
(*gains)(root_idx) =
872+
best_gain - root_stats.gain - state->tree_complexity_regularization();
873+
(*output_partition_ids)(root_idx) = partition_ids(start_index);
826874
}
827875
}
876+
877+
void ComputeObliviousDecisionTree(
878+
OpKernelContext* const context, SplitBuilderState* state,
879+
const float normalizer_ratio, const int num_elements,
880+
const std::vector<int32>& partition_boundaries,
881+
const std::vector<int32>& non_empty_partitions,
882+
const int64 bias_feature_id,
883+
const tensorflow::TTypes<int32>::ConstVec& partition_ids,
884+
const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
885+
const Tensor* gradients_t, const Tensor* hessians_t,
886+
tensorflow::TTypes<int32>::Vec* output_partition_ids,
887+
tensorflow::TTypes<float>::Vec* gains,
888+
tensorflow::TTypes<string>::Vec* output_splits) {
889+
// Holds the root stats per each node to be split.
890+
std::vector<GradientStats> current_layer_stats;
891+
current_layer_stats.reserve(num_elements);
892+
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
893+
const int start_index = partition_boundaries[root_idx];
894+
// First feature ID in each partition should be the bias feature.
895+
OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
896+
errors::InvalidArgument("Bias feature ID missing."));
897+
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
898+
root_gradient_stats *= normalizer_ratio;
899+
current_layer_stats.push_back(root_gradient_stats);
900+
}
901+
float best_gain = std::numeric_limits<float>::lowest();
902+
int64 best_feature_id = 0;
903+
std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
904+
std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
905+
std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
906+
std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
907+
int64 current_feature_id = std::numeric_limits<int64>::max();
908+
int64 last_feature_id = -1;
909+
// Find the lowest feature id; this is going to be the first feature id to
910+
// try.
911+
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
912+
const int start_index = partition_boundaries[root_idx];
913+
if (feature_ids(start_index + 1, 0) < current_feature_id) {
914+
current_feature_id = feature_ids(start_index + 1, 0);
915+
}
916+
}
917+
// Index offsets for each of the partitions that can be used to access
918+
// gradients of a partition for the current feature we consider. Start at one
919+
// because the zero index is for the bias.
920+
std::vector<int> current_layer_offsets(num_elements, 1);
921+
// The idea is to try every feature id in increasing order. In each
922+
// iteration we calculate the gain of the layer using the current feature id
923+
// as split value, and we also obtain the following feature id to try.
924+
while (current_feature_id > last_feature_id) {
925+
last_feature_id = current_feature_id;
926+
int64 next_feature_id = -1;
927+
// Left gradient stats per node.
928+
std::vector<GradientStats> left_gradient_stats(num_elements);
929+
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
930+
int idx =
931+
current_layer_offsets[root_idx] + partition_boundaries[root_idx];
932+
const int end_index = partition_boundaries[root_idx + 1];
933+
if (idx < end_index && feature_ids(idx, 0) == current_feature_id) {
934+
GradientStats g(*gradients_t, *hessians_t, idx);
935+
g *= normalizer_ratio;
936+
left_gradient_stats[root_idx] = g;
937+
current_layer_offsets[root_idx]++;
938+
idx++;
939+
}
940+
if (idx < end_index &&
941+
(feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) {
942+
next_feature_id = feature_ids(idx, 0);
943+
}
944+
}
945+
float gain_of_split = 0.0;
946+
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
947+
GradientStats right_gradient_stats =
948+
current_layer_stats[root_idx] - left_gradient_stats[root_idx];
949+
NodeStats left_stat =
950+
state->ComputeNodeStats(left_gradient_stats[root_idx]);
951+
NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
952+
gain_of_split += left_stat.gain + right_stat.gain;
953+
current_left_node_stats[root_idx] = left_stat;
954+
current_right_node_stats[root_idx] = right_stat;
955+
}
956+
if (gain_of_split > best_gain) {
957+
best_gain = gain_of_split;
958+
best_left_node_stats = current_left_node_stats;
959+
best_right_node_stats = current_right_node_stats;
960+
best_feature_id = current_feature_id;
961+
}
962+
current_feature_id = next_feature_id;
963+
}
964+
965+
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
966+
best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
967+
}
968+
best_gain -= num_elements * state->tree_complexity_regularization();
969+
970+
ObliviousSplitInfo oblivious_split_info;
971+
auto* equality_split =
972+
oblivious_split_info.mutable_split_node()
973+
->mutable_oblivious_categorical_id_binary_split();
974+
equality_split->set_feature_column(state->feature_column_group_id());
975+
equality_split->set_feature_id(best_feature_id);
976+
(*gains)(0) = best_gain;
977+
978+
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
979+
auto* left_child = oblivious_split_info.add_children();
980+
auto* right_child = oblivious_split_info.add_children();
981+
982+
state->FillLeaf(best_left_node_stats[root_idx], left_child);
983+
state->FillLeaf(best_right_node_stats[root_idx], right_child);
984+
985+
const int start_index = partition_boundaries[root_idx];
986+
(*output_partition_ids)(root_idx) = partition_ids(start_index);
987+
oblivious_split_info.add_children_parent_id(partition_ids(start_index));
988+
}
989+
oblivious_split_info.SerializeToString(&(*output_splits)(0));
990+
}
828991
};
829992

830993
REGISTER_KERNEL_BUILDER(

tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import print_function
2020

2121
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
22+
from tensorflow.contrib.boosted_trees.proto import learner_pb2
2223
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
2324
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
2425
from tensorflow.python.framework import constant_op
@@ -46,6 +47,7 @@ def __init__(self,
4647
multiclass_strategy,
4748
init_stamp_token=0,
4849
loss_uses_sum_reduction=False,
50+
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
4951
name=None):
5052
"""Initialize the internal state for this split handler.
5153
@@ -66,6 +68,7 @@ def __init__(self,
6668
stamped objects.
6769
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
6870
SUM or MEAN reduction was used for the loss.
71+
weak_learner_type: Specifies the type of weak learner to use.
6972
name: An optional handler name.
7073
"""
7174
super(EqualitySplitHandler, self).__init__(
@@ -85,6 +88,7 @@ def __init__(self,
8588
hessian_shape,
8689
name="StatsAccumulator/{}".format(self._name))
8790
self._sparse_int_column = sparse_int_column
91+
self._weak_learner_type = weak_learner_type
8892

8993
def update_stats(self, stamp_token, example_partition_ids, gradients,
9094
hessians, empty_gradients, empty_hessians, weights,
@@ -197,7 +201,8 @@ def make_splits(self, stamp_token, next_stamp_token, class_id):
197201
tree_complexity_regularization=self._tree_complexity_regularization,
198202
min_node_weight=self._min_node_weight,
199203
bias_feature_id=_BIAS_FEATURE_ID,
200-
multiclass_strategy=self._multiclass_strategy))
204+
multiclass_strategy=self._multiclass_strategy,
205+
weak_learner_type=self._weak_learner_type))
201206
# There are no warm-up rounds needed in the equality column handler. So we
202207
# always return ready.
203208
are_splits_ready = constant_op.constant(True)

0 commit comments

Comments
 (0)