@@ -77,6 +77,51 @@ class SoftmaxOp : public XlaOpKernel {
 REGISTER_XLA_OP("Softmax", SoftmaxOp);
 REGISTER_XLA_OP("LogSoftmax", SoftmaxOp);

+std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
+CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
+                       const xla::ComputationDataHandle& logits,
+                       const xla::ComputationDataHandle& labels) {
+  const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+  const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
+
+  const int kBatchDim = 0;
+  const int kClassDim = 1;
+
+  xla::ComputationBuilder* b = ctx->builder();
+  // Find the max in each batch, resulting in a tensor of shape [batch]
+  auto logits_max =
+      b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+
+  // Subtract the max in batch b from every element in batch b.
+  // Broadcasts along the batch dimension.
+  auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+
+  // exp(logits - max_logits)
+  auto exp_shifted_logits = b->Exp(shifted_logits);
+
+  // sum_{class} (exp(logits - max_logits))
+  auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
+                           add_func, {kClassDim});
+
+  // log(sum(exp(logits - max_logits)))
+  auto log_sum_exp = b->Log(sum_exp);
+
+  // sum(-labels *
+  //     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+  // along classes
+  // (The subtraction broadcasts along the batch dimension.)
+  xla::ComputationDataHandle loss = b->Reduce(
+      b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
+      XlaHelpers::Zero(b, type), add_func, {kClassDim});
+
+  // backprop: prob - labels, where
+  //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+  // (where the division broadcasts along the batch dimension)
+  xla::ComputationDataHandle backprop =
+      b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+  return {loss, backprop};
+}
+
 class SoftmaxXentWithLogitsOp : public XlaOpKernel {
  public:
   explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
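(Aside, not part of the commit: for readers following the math rather than the XLA builder calls, below is a minimal standalone sketch in plain C++ of the same numerically stable computation that CrossEntropyWithLogits builds above: shift each row by its max, take log-sum-exp, then loss = sum(-labels * (shifted_logits - log_sum_exp)) per row and backprop = softmax - labels. All names and values are purely illustrative.)

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int batch = 2, classes = 3;
  std::vector<std::vector<double>> logits = {{1.0, 2.0, 3.0}, {0.5, 0.5, 0.5}};
  std::vector<std::vector<double>> labels = {{0.0, 0.0, 1.0}, {0.0, 1.0, 0.0}};

  for (int b = 0; b < batch; ++b) {
    // Max over classes, used to shift logits for numerical stability.
    double max_logit = logits[b][0];
    for (double x : logits[b]) max_logit = std::max(max_logit, x);

    // sum(exp(logits - max_logits)) and its log.
    double sum_exp = 0.0;
    for (double x : logits[b]) sum_exp += std::exp(x - max_logit);
    double log_sum_exp = std::log(sum_exp);

    // loss = sum(-labels * ((logits - max_logits) - log_sum_exp)) over classes
    // backprop = exp(logits - max_logits) / sum_exp - labels
    double loss = 0.0;
    for (int c = 0; c < classes; ++c) {
      double shifted = logits[b][c] - max_logit;
      loss += -labels[b][c] * (shifted - log_sum_exp);
      double backprop = std::exp(shifted) / sum_exp - labels[b][c];
      std::printf("backprop[%d][%d] = %f\n", b, c, backprop);
    }
    std::printf("loss[%d] = %f\n", b, loss);
  }
  return 0;
}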
@@ -88,65 +133,95 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
     OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                 errors::InvalidArgument(
                     "logits and labels must be same size: logits_size=",
-                    logits_shape.DebugString(), " labels_size=",
-                    labels_shape.DebugString()));
+                    logits_shape.DebugString(),
+                    " labels_size=", labels_shape.DebugString()));
     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                 errors::InvalidArgument("logits must be 2-dimensional"));
     // As we already tested that both inputs have the same shape no need to
     // check that "labels" is a matrix too.

-    // loss is 1-D (one per example), and size is batch_size.
-
-    const int kBatchDim = 0;
-    const int kClassDim = 1;
-
     const DataType type = input_type(0);
-    xla::ComputationBuilder* b = ctx->builder();
     auto logits = ctx->Input(0);
     auto labels = ctx->Input(1);

-    const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
-    const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
-
-    // Find the max in each batch, resulting in a tensor of shape [batch]
-    auto logits_max =
-        b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
-
-    // Subtract the max in batch b from every element in batch b.
-    // Broadcasts along the batch dimension.
-    auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
-
-    // exp(logits - max_logits)
-    auto exp_shifted_logits = b->Exp(shifted_logits);
-
-    // sum_{class} (exp(logits - max_logits))
-    auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
-                             add_func, {kClassDim});
-
-    // log(sum(exp(logits - max_logits)))
-    auto log_sum_exp = b->Log(sum_exp);
+    xla::ComputationDataHandle loss, backprop;
+    std::tie(loss, backprop) =
+        CrossEntropyWithLogits(ctx, type, logits, labels);
+    ctx->SetOutput(0, loss);
+    ctx->SetOutput(1, backprop);
+  }
+};

-    // sum(-labels *
-    //     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
-    // along classes
-    // (The subtraction broadcasts along the batch dimension.)
-    xla::ComputationDataHandle loss =
-        b->Reduce(b->Mul(b->Neg(labels),
-                         b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
-                  XlaHelpers::Zero(b, type), add_func, {kClassDim});
+REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp);

-    // backprop: prob - labels, where
-    //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
-    // (where the division broadcasts along the batch dimension)
-    xla::ComputationDataHandle backprop =
-        b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
+ public:
+  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx) {}

+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape logits_shape = ctx->InputShape(0);
+    const TensorShape labels_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits_shape.DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels_shape.DebugString()));
+    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
+                errors::InvalidArgument(
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits_shape.DebugString(), " and labels shape ",
+                    labels_shape.DebugString()));
+    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits_shape.DebugString()));
+
+    int64 batch_size = logits_shape.dim_size(0);
+    int64 depth = logits_shape.dim_size(1);
+
+    DataType logits_type = input_type(0);
+    DataType indices_type = input_type(1);
+
+    xla::ComputationDataHandle indices = ctx->Input(1);
+
+    xla::ComputationBuilder* builder = ctx->builder();
+    xla::ComputationDataHandle labels;
+    OP_REQUIRES_OK(ctx,
+                   XlaHelpers::OneHot(
+                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
+                       indices, XlaHelpers::One(builder, logits_type),
+                       XlaHelpers::Zero(builder, logits_type), &labels));
+
+    // If any of the indices are out of range, we must populate the labels with
+    // NaNs to obey the interface contract of
+    // tf.nn.sparse_softmax_cross_entropy_with_logits.
+    // Builds a vector of {batch_size} that is 0 if the index is in range, or
+    // NaN otherwise; then add that vector to the labels to force out-of-range
+    // values to NaNs.
+    xla::ComputationDataHandle nan_or_zero = builder->Select(
+        builder->LogicalAnd(
+            builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
+            builder->Lt(indices, XlaHelpers::IntegerLiteral(
+                                     builder, indices_type, depth))),
+        builder->Broadcast(XlaHelpers::Zero(builder, logits_type),
+                           {batch_size}),
+        builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
+                           {batch_size}));
+    labels = builder->Add(labels, nan_or_zero, {0});
+
+    xla::ComputationDataHandle loss, backprop;
+    std::tie(loss, backprop) =
+        CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
     ctx->SetOutput(0, loss);
     ctx->SetOutput(1, backprop);
   }
 };

-REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp);
+REGISTER_XLA_OP("SparseSoftmaxCrossEntropyWithLogits",
+                SparseSoftmaxXentWithLogitsOp);

 }  // namespace
 }  // namespace tensorflow
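(Aside, not part of the commit: a plain-C++ sketch of how SparseSoftmaxXentWithLogitsOp prepares its dense labels. The sparse indices are one-hot encoded, and every row whose index falls outside [0, depth) is forced to NaN by adding a per-example NaN-or-zero vector, mirroring the contract of tf.nn.sparse_softmax_cross_entropy_with_logits. Names and values below are made up for illustration.)

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 3, depth = 4;
  std::vector<long long> indices = {2, 7, 0};  // 7 is out of range

  // One-hot labels: labels[b][c] = 1 if indices[b] == c, else 0.
  std::vector<std::vector<double>> labels(batch_size,
                                          std::vector<double>(depth, 0.0));
  for (int b = 0; b < batch_size; ++b) {
    if (indices[b] >= 0 && indices[b] < depth) labels[b][indices[b]] = 1.0;
  }

  // nan_or_zero[b] is 0 for a valid index and NaN otherwise; adding it
  // along the batch dimension poisons out-of-range rows with NaNs.
  for (int b = 0; b < batch_size; ++b) {
    double nan_or_zero = (indices[b] >= 0 && indices[b] < depth) ? 0.0 : NAN;
    for (int c = 0; c < depth; ++c) labels[b][c] += nan_or_zero;
  }

  for (int b = 0; b < batch_size; ++b) {
    for (int c = 0; c < depth; ++c) std::printf("%g ", labels[b][c]);
    std::printf("\n");
  }
  return 0;
}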