
Commit 2ca071d

Remove double precision math from LogSigmoid too
1 parent 8a901c5 commit 2ca071d
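
As the diff below shows, the old functors added the double literal 1. and converted the result back with ScalarConvert<double, T>, so exp/log were evaluated in double precision regardless of the tensor type; the new functors route through THCNumerics<T> with 1.f so the math stays in the tensor's own precision, and half specializations are added. For reference, a minimal standalone single-precision sketch of the math involved (illustrative helper names, not code from this commit):

// Reference-only sketch of the logsigmoid forward/backward math in float.
#include <math.h>

static inline float logsigmoid_fwd(float x) {
  // output = -log(1 + exp(-x))
  return -logf(1.f + expf(-x));
}

static inline float logsigmoid_bwd(float x, float gradOutput) {
  // with z = exp(-x): d/dx logsigmoid(x) = z / (1 + z)
  const float z = expf(-x);
  return gradOutput * z / (1.f + z);
}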


6 files changed: +49 -22 lines changed


LogSigmoid.cu

Lines changed: 35 additions & 8 deletions
@@ -6,22 +6,49 @@
 template <typename T>
 struct logSigmoid_updateOutput_functor
 {
-  __device__ void operator()(T *output, const T *input) const
-  {
-    T z = exp(-*input);
-    *output = ScalarConvert<double, T>::to(-log(1. + z));
+  __device__ void operator()(T *output, const T *input) const {
+    *output = -THCNumerics<T>::log(1.f + THCNumerics<T>::exp(- *input));
   }
 };
 
 template <typename T>
 struct logSigmoid_updateGradInput_functor
 {
-  __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const
-  {
-    T z = exp(-*input);
-    *gradInput = ScalarConvert<double, T>::to(*gradOutput * z / (1. + z));
+  __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const {
+    const T z = THCNumerics<T>::exp(- *input);
+    *gradInput = *gradOutput * z / (1.f + z);
   }
 };
 
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct logSigmoid_updateOutput_functor<half> {
+  __device__ __forceinline__ void operator()(half* output, const half *input) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+    const half one = __float2half(1.f);
+    *output = __hneg(THCNumerics<half>::log(one + THCNumerics<half>::exp(__hneg(*input))));
+#else
+    float in = __half2float(*input);
+    *output = __float2half(-THCNumerics<float>::log(1.f + THCNumerics<float>::exp(-in)));
+#endif
+  }
+};
+
+template <>
+struct logSigmoid_updateGradInput_functor<half> {
+  __device__ __forceinline__ void operator()(half* gradInput, const half *input, const half *gradOutput) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+    const half one = __float2half(1.f);
+    const half in_exp = THCNumerics<half>::exp(__hneg(*input));
+    *gradInput = hdiv(__hmul(*gradOutput, in_exp), __hadd(one, in_exp));
+#else
+    const float in_exp = THCNumerics<float>::exp(-(__half2float(*input)));
+    const float go = __half2float(*gradOutput);
+    *gradInput = __float2half(go * in_exp / (1.f + in_exp));
+#endif
+  }
+};
+#endif
+
 #include "generic/LogSigmoid.cu"
 #include "THCGenerateFloatTypes.h"

Sigmoid.cu

Lines changed: 5 additions & 5 deletions
@@ -4,22 +4,22 @@
 #include <THC/THCApply.cuh>
 
 template <typename T>
-struct SigmoidGradInputOp {
+struct sigmoid_updateGradInput_functor {
   __device__ __forceinline__ void operator()(T* gradInput, const T *output, const T *gradOutput) const {
     *gradInput = *gradOutput * (1.f - *output) * (*output);
   }
 };
 
 #ifdef CUDA_HALF_TENSOR
 template <>
-struct SigmoidGradInputOp<half> {
+struct sigmoid_updateGradInput_functor<half> {
   __device__ __forceinline__ void operator()(half* gradInput, const half *output, const half *gradOutput) const {
 #ifdef CUDA_HALF_INSTRUCTIONS
-    half one = __float2half(1.f);
+    const half one = __float2half(1.f);
     *gradInput = __hmul(*gradOutput, __hmul(__hadd(one, __hneg(*output)), *output));
 #else
-    float out = __half2float(*output);
-    float go = __half2float(*gradOutput);
+    const float out = __half2float(*output);
+    const float go = __half2float(*gradOutput);
     *gradInput = __float2half(go * (1.f - out) * out);
 #endif
   }
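
Note that the sigmoid backward functor reads only the saved output: with output = sigmoid(input) and

  d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)),

the chain rule gives gradInput = gradOutput * output * (1 - output), with no dependence on input. This is consistent with generic/Sigmoid.cu below shape-checking output instead of input, and with input being marked [OPTIONAL] in generic/THCUNN.h.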

Tanh.cu

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 #include <THC/THCApply.cuh>
 
 template <typename T>
-struct TanhGradInputOp
+struct tanh_updateGradInput_functor
 {
   __device__ __forceinline__ void operator()(T *gradInput,
                                              const T *output, const T *gradOutput) const {
@@ -14,7 +14,7 @@ struct TanhGradInputOp
 
 #ifdef CUDA_HALF_TENSOR
 template <>
-struct TanhGradInputOp<half>
+struct tanh_updateGradInput_functor<half>
 {
   __device__ __forceinline__ void operator()(half *gradInput,
                                              const half *output, const half *gradOutput) const {
@@ -23,8 +23,8 @@ struct TanhGradInputOp<half>
     const half out_square = __hmul(*output, *output);
     *gradInput = __hmul(*gradOutput, __hadd(one, __hneg(out_square)));
 #else
-    float out = __half2float(*output);
-    float go = __half2float(*gradOutput);
+    const float out = __half2float(*output);
+    const float go = __half2float(*gradOutput);
     *gradInput = __float2half(go * (1.f - out * out));
 #endif
   }
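
Likewise for tanh: with output = tanh(input) and

  d/dx tanh(x) = 1 - tanh(x)^2,

the backward is gradInput = gradOutput * (1 - output * output), again computed from the saved output alone, matching the [OPTIONAL] annotation on Tanh_updateGradInput's input argument in generic/THCUNN.h below.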

generic/Sigmoid.cu

Lines changed: 2 additions & 2 deletions
@@ -20,10 +20,10 @@ void THNN_(Sigmoid_updateGradInput)(
            THCTensor *gradInput,
            THCTensor *output)
 {
-  THCUNN_check_nElement(state, input, gradOutput);
+  THCUNN_check_nElement(state, output, gradOutput);
   THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
   THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, SigmoidGradInputOp<real>());
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoid_updateGradInput_functor<real>());
 }
 
 #endif

generic/THCUNN.h

Lines changed: 2 additions & 2 deletions
@@ -911,7 +911,7 @@ TH_API void THNN_(Sigmoid_updateOutput)(
 
 TH_API void THNN_(Sigmoid_updateGradInput)(
                   THCState *state,
-                  THCTensor *input,
+                  THCTensor *input,              // [OPTIONAL]
                   THCTensor *gradOutput,
                   THCTensor *gradInput,
                   THCTensor *output);
@@ -1002,7 +1002,7 @@ TH_API void THNN_(Tanh_updateOutput)(
 
 TH_API void THNN_(Tanh_updateGradInput)(
                   THCState *state,
-                  THCTensor *input,
+                  THCTensor *input,              // [OPTIONAL]
                   THCTensor *gradOutput,
                   THCTensor *gradInput,
                   THCTensor *output);

generic/Tanh.cu

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ void THNN_(Tanh_updateGradInput)(
   THCUNN_check_shape(state, output, gradOutput);
   THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
   THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, TanhGradInputOp<real>());
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, tanh_updateGradInput_functor<real>());
 }
 
 #endif
