Commit 1708b92

Merge pull request tensorflow#8066 from jhseu/branch_149155199
Branch 149155199
2 parents 067cba5 + 571c40a


72 files changed: +2516 −706 lines


tensorflow/c/c_api.cc

Lines changed: 1 addition & 1 deletion
@@ -730,7 +730,7 @@ extern "C" {
 struct TF_Graph {
   TF_Graph()
       : graph(OpRegistry::Global()),
-        refiner(graph.op_registry()),
+        refiner(graph.versions().producer(), graph.op_registry()),
         num_sessions(0),
         delete_requested(false),
         parent(nullptr),

tensorflow/cc/framework/scope.cc

Lines changed: 2 additions & 1 deletion
@@ -118,7 +118,8 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,

 Scope Scope::NewRootScope() {
   Graph* graph = new Graph(OpRegistry::Global());
-  ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry());
+  ShapeRefiner* refiner =
+      new ShapeRefiner(graph->versions().producer(), graph->op_registry());
   return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
 }

tensorflow/compiler/tests/binary_ops_test.py

Lines changed: 14 additions & 0 deletions
@@ -132,6 +132,20 @@ def testFloatOps(self):
           ],
           equality_test=self.ListsAreClose)

+      self._testBinary(
+          gen_nn_ops._sparse_softmax_cross_entropy_with_logits,
+          np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8],
+                    [0.9, 1.0, 1.1, 1.2]], dtype=dtype),
+          np.array([2, 1, 7], dtype=np.int32),
+          expected=[
+              np.array([1.342536, 1.442536, np.nan], dtype=dtype),
+              np.array([[0.213838, 0.236328, -0.738817, 0.288651],
+                        [0.213838, -0.763672, 0.261183, 0.288651],
+                        [np.nan, np.nan, np.nan, np.nan]],
+                       dtype=dtype),
+          ],
+          equality_test=self.ListsAreClose)
+
   def testIntOps(self):
     for dtype in self.int_types:
       self._testBinary(
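For reference, a minimal NumPy sketch (not part of the commit) that reproduces the expected values in the test above: standard softmax cross-entropy for the in-range labels, and NaN for the out-of-range label 7, since the logits have only 4 classes.

import numpy as np

# Test inputs from the diff above.
logits = np.array([[0.1, 0.2, 0.3, 0.4],
                   [0.5, 0.6, 0.7, 0.8],
                   [0.9, 1.0, 1.1, 1.2]])
labels = np.array([2, 1, 7])
num_classes = logits.shape[1]

shifted = logits - logits.max(axis=1, keepdims=True)   # stabilize before exp
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
valid = (labels >= 0) & (labels < num_classes)
clipped = np.clip(labels, 0, num_classes - 1)

loss = -np.log(probs[np.arange(len(labels)), clipped])
backprop = probs - np.eye(num_classes)[clipped]
loss[~valid] = np.nan          # out-of-range label -> NaN loss
backprop[~valid] = np.nan      # ...and an all-NaN backprop row

print(loss)      # approx [1.342536, 1.442536, nan]
print(backprop)  # rows approx [0.213838, 0.236328, -0.738817, 0.288651], ..., [nan, nan, nan, nan]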

tensorflow/compiler/tests/randomized_tests.cc

Lines changed: 46 additions & 15 deletions
@@ -1773,22 +1773,13 @@ TEST_F(OpTest, Softmax) {
   });
 }

-TEST_F(OpTest, Split) {
+TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) {
   Repeatedly([this]() {
-    DataType type = Choose<DataType>(kAllXlaTypes);
-    std::vector<int64> dims = RandomDims(1);
-    std::uniform_int_distribution<int> ud;
-    int32 dim = std::uniform_int_distribution<int32>(
-        0, static_cast<int32>(dims.size()) - 1)(generator());
-    int n = std::uniform_int_distribution<int>(1, 5)(generator());
-    // Ensure 'dim' is evenly divisible by 'n'.
-    dims[dim] /= n;
-    dims[dim] *= n;
-    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
-                                      .Input(test::AsScalar<int32>(dim))
-                                      .Input(RandomTensor(type, dims))
-                                      .Attr("T", type)
-                                      .Attr("num_split", n));
+    std::vector<int64> dims = RandomDims(2, 2, 1);
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftmaxCrossEntropyWithLogits")
+                                      .Input(RandomTensor(DT_FLOAT, dims))
+                                      .Input(RandomTensor(DT_FLOAT, dims))
+                                      .Attr("T", DT_FLOAT));
   });
 }

@@ -1846,6 +1837,46 @@ TEST_F(OpTest, SparseMatMul) {
   });
 }

+TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) {
+  Repeatedly([this]() {
+    std::vector<int64> dims = RandomDims(2, 2, 1);
+    int64 batch_size = dims[0];
+    int64 num_classes = dims[1];
+
+    std::vector<int32> indices(batch_size);
+    for (int64 i = 0; i < batch_size; ++i) {
+      indices[i] =
+          std::uniform_int_distribution<int32>(0, num_classes - 1)(generator());
+    }
+
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits")
+            .Input(RandomTensor(DT_FLOAT, dims))
+            .Input(test::AsTensor<int32>(indices))
+            .Attr("T", DT_FLOAT)
+            .Attr("Tlabels", DT_INT32));
+  });
+}
+
+TEST_F(OpTest, Split) {
+  Repeatedly([this]() {
+    DataType type = Choose<DataType>(kAllXlaTypes);
+    std::vector<int64> dims = RandomDims(1);
+    std::uniform_int_distribution<int> ud;
+    int32 dim = std::uniform_int_distribution<int32>(
+        0, static_cast<int32>(dims.size()) - 1)(generator());
+    int n = std::uniform_int_distribution<int>(1, 5)(generator());
+    // Ensure 'dim' is evenly divisible by 'n'.
+    dims[dim] /= n;
+    dims[dim] *= n;
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
+                                      .Input(test::AsScalar<int32>(dim))
+                                      .Input(RandomTensor(type, dims))
+                                      .Attr("T", type)
+                                      .Attr("num_split", n));
+  });
+}
+
 TEST_F(OpTest, Sqrt) {
   Repeatedly([this]() {
     ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sqrt")

tensorflow/compiler/tf2xla/kernels/one_hot_op.cc

Lines changed: 7 additions & 56 deletions
@@ -16,22 +16,13 @@ limitations under the License.
 // XLA implementation of OneHot operator.

 #include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"

 namespace tensorflow {
 namespace {

-template <typename T>
-Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
-  Tensor linspace(DataTypeToEnum<T>::v(), shape);
-  auto linspace_flat = linspace.flat<T>();
-  for (int64 i = 0; i < depth; ++i) {
-    linspace_flat(i) = i;
-  }
-  return linspace;
-}
-
 class OneHotOp : public XlaOpKernel {
  public:
   explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {

@@ -71,52 +62,12 @@ class OneHotOp : public XlaOpKernel {
         ctx, depth >= 0,
         errors::InvalidArgument("depth must be non-negative, got: ", depth));

-    TensorShape output_shape = indices_shape;
-    output_shape.InsertDim(axis, depth);
-
-    xla::ComputationDataHandle on_value = ctx->Input(2);
-    xla::ComputationDataHandle off_value = ctx->Input(3);
-
-    // Build a Tensor populated with values 0, 1, 2, ... depth.
-    std::vector<int64> linspace_dims(output_dims, 1);
-    linspace_dims[axis] = depth;
-    TensorShape linspace_shape(linspace_dims);
-    Tensor linspace;
-    switch (ctx->input_type(0)) {
-      case DT_UINT8:
-        linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
-        break;
-      case DT_INT32:
-        linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
-        break;
-      case DT_INT64:
-        linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
-        break;
-      default:
-        ctx->SetStatus(errors::InvalidArgument(
-            "Invalid argument type ", DataTypeString(ctx->input_type(0))));
-        return;
-    }
-    xla::Literal linspace_literal;
-    OP_REQUIRES_OK(ctx, HostTensorToLiteral(linspace, &linspace_literal));
-
-    xla::ComputationBuilder* builder = ctx->builder();
-    xla::ComputationDataHandle indices = ctx->Input(0);
-
-    // Broadcast the linspace constant across the indices along the new axis,
-    // and test equality at each position.
-    std::vector<int64> broadcast_dims(indices_shape.dims());
-    std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
-    std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
-    xla::ComputationDataHandle one_hot =
-        builder->Eq(indices, builder->ConstantLiteral(linspace_literal),
-                    broadcast_dims);
-
-    // Selects the user-provided off_value and on_value values.
-    ctx->SetOutput(
-        0, builder->Select(
-               one_hot, builder->Broadcast(on_value, output_shape.dim_sizes()),
-               builder->Broadcast(off_value, output_shape.dim_sizes())));
+    xla::ComputationDataHandle one_hot;
+    OP_REQUIRES_OK(
+        ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0),
+                                indices_shape, ctx->Input(0), ctx->Input(2),
+                                ctx->Input(3), &one_hot));
+    ctx->SetOutput(0, one_hot);
   }

  private:
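The removed linspace/Eq/Select sequence is factored out into XlaHelpers::OneHot. As a rough NumPy sketch (illustrative only, not the XLA API), the underlying technique is assumed to be: compare the indices against an iota vector broadcast along the new axis, then select on_value or off_value; an out-of-range index simply matches nothing.

import numpy as np

def one_hot(indices, depth, on_value=1.0, off_value=0.0):
    # Appends the new class axis last, mirroring the axis handling in the op.
    indices = np.asarray(indices)
    iota = np.arange(depth)                       # values 0, 1, ..., depth - 1
    mask = np.expand_dims(indices, -1) == iota    # broadcast equality test
    return np.where(mask, on_value, off_value)

print(one_hot([2, 1, 7], depth=4))
# [[0. 0. 1. 0.]
#  [0. 1. 0. 0.]
#  [0. 0. 0. 0.]]   <- an out-of-range index produces an all-off row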

tensorflow/compiler/tf2xla/kernels/softmax_op.cc

Lines changed: 117 additions & 42 deletions
@@ -77,6 +77,51 @@ class SoftmaxOp : public XlaOpKernel {
 REGISTER_XLA_OP("Softmax", SoftmaxOp);
 REGISTER_XLA_OP("LogSoftmax", SoftmaxOp);

+std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
+CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
+                       const xla::ComputationDataHandle& logits,
+                       const xla::ComputationDataHandle& labels) {
+  const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+  const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
+
+  const int kBatchDim = 0;
+  const int kClassDim = 1;
+
+  xla::ComputationBuilder* b = ctx->builder();
+  // Find the max in each batch, resulting in a tensor of shape [batch]
+  auto logits_max =
+      b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+
+  // Subtract the max in batch b from every element in batch b.
+  // Broadcasts along the batch dimension.
+  auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+
+  // exp(logits - max_logits)
+  auto exp_shifted_logits = b->Exp(shifted_logits);
+
+  // sum_{class} (exp(logits - max_logits))
+  auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
+                           add_func, {kClassDim});
+
+  // log(sum(exp(logits - max_logits)))
+  auto log_sum_exp = b->Log(sum_exp);
+
+  // sum(-labels *
+  //     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+  // along classes
+  // (The subtraction broadcasts along the batch dimension.)
+  xla::ComputationDataHandle loss = b->Reduce(
+      b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
+      XlaHelpers::Zero(b, type), add_func, {kClassDim});
+
+  // backprop: prob - labels, where
+  //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+  // (where the division broadcasts along the batch dimension)
+  xla::ComputationDataHandle backprop =
+      b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+  return {loss, backprop};
+}
+
 class SoftmaxXentWithLogitsOp : public XlaOpKernel {
  public:
   explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)

@@ -88,65 +133,95 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
     OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                 errors::InvalidArgument(
                     "logits and labels must be same size: logits_size=",
-                    logits_shape.DebugString(), " labels_size=",
-                    labels_shape.DebugString()));
+                    logits_shape.DebugString(),
+                    " labels_size=", labels_shape.DebugString()));
     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                 errors::InvalidArgument("logits must be 2-dimensional"));
     // As we already tested that both inputs have the same shape no need to
     // check that "labels" is a matrix too.

-    // loss is 1-D (one per example), and size is batch_size.
-
-    const int kBatchDim = 0;
-    const int kClassDim = 1;
-
     const DataType type = input_type(0);
-    xla::ComputationBuilder* b = ctx->builder();
     auto logits = ctx->Input(0);
     auto labels = ctx->Input(1);

-    const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
-    const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
-
-    // Find the max in each batch, resulting in a tensor of shape [batch]
-    auto logits_max =
-        b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
-
-    // Subtract the max in batch b from every element in batch b.
-    // Broadcasts along the batch dimension.
-    auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
-
-    // exp(logits - max_logits)
-    auto exp_shifted_logits = b->Exp(shifted_logits);
-
-    // sum_{class} (exp(logits - max_logits))
-    auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
-                             add_func, {kClassDim});
-
-    // log(sum(exp(logits - max_logits)))
-    auto log_sum_exp = b->Log(sum_exp);
+    xla::ComputationDataHandle loss, backprop;
+    std::tie(loss, backprop) =
+        CrossEntropyWithLogits(ctx, type, logits, labels);
+    ctx->SetOutput(0, loss);
+    ctx->SetOutput(1, backprop);
+  }
+};

-    // sum(-labels *
-    //     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
-    // along classes
-    // (The subtraction broadcasts along the batch dimension.)
-    xla::ComputationDataHandle loss =
-        b->Reduce(b->Mul(b->Neg(labels),
-                         b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
-                  XlaHelpers::Zero(b, type), add_func, {kClassDim});
+REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp);

-    // backprop: prob - labels, where
-    //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
-    // (where the division broadcasts along the batch dimension)
-    xla::ComputationDataHandle backprop =
-        b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
+ public:
+  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {}

+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape logits_shape = ctx->InputShape(0);
+    const TensorShape labels_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels_shape.DebugString()));
+    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
+                errors::InvalidArgument(
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits_shape.DebugString(), " and labels shape ",
+                    labels_shape.DebugString()));
+    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits_shape.DebugString()));
+
+    int64 batch_size = logits_shape.dim_size(0);
+    int64 depth = logits_shape.dim_size(1);
+
+    DataType logits_type = input_type(0);
+    DataType indices_type = input_type(1);
+
+    xla::ComputationDataHandle indices = ctx->Input(1);
+
+    xla::ComputationBuilder* builder = ctx->builder();
+    xla::ComputationDataHandle labels;
+    OP_REQUIRES_OK(ctx,
+                   XlaHelpers::OneHot(
+                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
+                       indices, XlaHelpers::One(builder, logits_type),
+                       XlaHelpers::Zero(builder, logits_type), &labels));
+
+    // If any of the indices are out of range, we must populate the labels with
+    // NaNs to obey the interface contract of
+    // tf.nn.sparse_softmax_cross_entropy_with_logits.
+    // Builds a vector of {batch_size} that is 0 if the index is in range, or
+    // NaN otherwise; then add that vector to the labels to force out-of-range
+    // values to NaNs.
+    xla::ComputationDataHandle nan_or_zero = builder->Select(
+        builder->LogicalAnd(
+            builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
+            builder->Lt(indices, XlaHelpers::IntegerLiteral(
+                                     builder, indices_type, depth))),
+        builder->Broadcast(XlaHelpers::Zero(builder, logits_type),
+                           {batch_size}),
+        builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
+                           {batch_size}));
+    labels = builder->Add(labels, nan_or_zero, {0});
+
+    xla::ComputationDataHandle loss, backprop;
+    std::tie(loss, backprop) =
+        CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
     ctx->SetOutput(0, loss);
     ctx->SetOutput(1, backprop);
   }
 };

-REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp);
+REGISTER_XLA_OP("SparseSoftmaxCrossEntropyWithLogits",
+                SparseSoftmaxXentWithLogitsOp);

 }  // namespace
 }  // namespace tensorflow
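The new CrossEntropyWithLogits helper is shared by both kernels. A small NumPy sketch (illustrative only, not the XLA builder API) of the math it emits: shift each row by its max for numerical stability, take log-sum-exp over classes, then loss = sum(-labels * (shifted - log_sum_exp)) per row and backprop = softmax(logits) - labels.

import numpy as np

def cross_entropy_with_logits(logits, labels):
    # logits, labels: [batch, classes]; labels is a (possibly soft) distribution.
    logits_max = logits.max(axis=1, keepdims=True)            # reduce over classes
    shifted = logits - logits_max                              # broadcast over batch
    exp_shifted = np.exp(shifted)
    sum_exp = exp_shifted.sum(axis=1, keepdims=True)
    log_sum_exp = np.log(sum_exp)
    loss = (-labels * (shifted - log_sum_exp)).sum(axis=1)     # [batch]
    backprop = exp_shifted / sum_exp - labels                  # [batch, classes]
    return loss, backprop

logits = np.array([[0.1, 0.2, 0.3, 0.4]])
labels = np.array([[0.0, 0.0, 1.0, 0.0]])
print(cross_entropy_with_logits(logits, labels))
# approx ([1.342536], [[0.213838, 0.236328, -0.738817, 0.288651]])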
