Skip to content

Commit 63da754

Browse files
Fix tensor_forest to properly support float input weights.
Allow tensor_forest inference to pass kwargs through to trees. Change: 130940805
1 parent 4e37b20 commit 63da754

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,9 @@ class CountExtremelyRandomStats : public OpKernel {
527527
auto out_leaves = output_leaves->unaligned_flat<int32>();
528528

529529
// <accumulator, class> -> count delta
530-
PairMapType<int32> total_delta;
530+
PairMapType<float> total_delta;
531531
// <accumulator, split, class> -> count delta
532-
TupleMapType<int32> split_delta;
532+
TupleMapType<float> split_delta;
533533

534534
for (int32 i = 0; i < num_data; ++i) {
535535
out_leaves(i) = results[i].node_indices.back();

tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def testSimple(self):
7171

7272
def testSimpleWeighted(self):
7373
with self.test_session():
74-
input_weights = [1.0, 2.0, 3.0, 4.0]
74+
input_weights = [1.5, 2.0, 3.0, 4.0]
7575
(pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
7676
pcw_totals_indices, pcw_totals_sums, _,
7777
leaves) = (self.ops.count_extremely_random_stats(
@@ -90,12 +90,13 @@ def testSimpleWeighted(self):
9090
regression=False))
9191

9292
self.assertAllEqual(
93-
[[10., 1., 2., 3., 4.], [3., 1., 2., 0., 0.], [7., 0., 0., 3., 4.]],
93+
[[10.5, 1.5, 2., 3., 4.],
94+
[3.5, 1.5, 2., 0., 0.], [7., 0., 0., 3., 4.]],
9495
pcw_node_sums.eval())
9596
self.assertAllEqual([[0, 0, 0], [0, 0, 1]], pcw_splits_indices.eval())
96-
self.assertAllEqual([1., 1.], pcw_splits_sums.eval())
97+
self.assertAllEqual([1.5, 1.5], pcw_splits_sums.eval())
9798
self.assertAllEqual([[0, 2], [0, 0], [0, 1]], pcw_totals_indices.eval())
98-
self.assertAllEqual([2., 3., 1.], pcw_totals_sums.eval())
99+
self.assertAllEqual([2., 3.5, 1.5], pcw_totals_sums.eval())
99100
self.assertAllEqual([1, 1, 2, 2], leaves.eval())
100101

101102
def testMissingLabel(self):

tensorflow/contrib/tensor_forest/python/tensor_forest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,14 @@ def training_graph(self, input_data, input_labels, data_spec=None,
379379

380380
return control_flow_ops.group(*tree_graphs, name='train')
381381

382-
def inference_graph(self, input_data, data_spec=None):
382+
def inference_graph(self, input_data, data_spec=None, **inference_args):
383383
"""Constructs a TF graph for evaluating a random forest.
384384
385385
Args:
386386
input_data: A tensor or SparseTensor or placeholder for input data.
387387
data_spec: A list of tf.dtype values specifying the original types of
388388
each column.
389+
**inference_args: Keyword arguments to pass through to each tree.
389390
390391
Returns:
391392
The last op in the random forest inference graph.
@@ -397,8 +398,8 @@ def inference_graph(self, input_data, data_spec=None):
397398
tree_data = input_data
398399
if self.params.bagged_features:
399400
tree_data = self._bag_features(i, input_data)
400-
probabilities.append(self.trees[i].inference_graph(tree_data,
401-
data_spec))
401+
probabilities.append(self.trees[i].inference_graph(
402+
tree_data, data_spec, **inference_args))
402403
with ops.device(self.device_assigner.get_device(0)):
403404
all_predict = array_ops.pack(probabilities)
404405
return math_ops.div(
@@ -415,7 +416,7 @@ def average_size(self):
415416
for i in range(self.params.num_trees):
416417
with ops.device(self.device_assigner.get_device(i)):
417418
sizes.append(self.trees[i].size())
418-
return math_ops.reduce_mean(array_ops.pack(sizes))
419+
return math_ops.reduce_mean(math_ops.to_float(array_ops.pack(sizes)))
419420

420421
# pylint: disable=unused-argument
421422
def training_loss(self, features, labels):

0 commit comments

Comments
 (0)