@@ -44,14 +44,29 @@ namespace nvinfer1
         return DimsCHW(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
     }
 
-    __device__ float softplus(float x) { return (x > 20.0) ? x : log(1.0 + exp(x)); }
+    __device__ float tanh_activate_kernel(float x) { return (2 / (1 + expf(-2 * x)) - 1); }
+
+    __device__ float softplus_kernel(float x, float threshold = 20) {
+        if (x > threshold) return x;              // too large: softplus(x) ~= x, avoids overflow in expf
+        else if (x < -threshold) return expf(x);  // too small: softplus(x) ~= exp(x)
+        return logf(expf(x) + 1);
+    }
 
     __global__ void mish_kernel(const float *input, float *output, int num_elem) {
 
         int idx = threadIdx.x + blockDim.x * blockIdx.x;
         if (idx >= num_elem) return;
 
-        output[idx] = input[idx] * tanh(softplus(input[idx]));
+        // float t = exp(input[idx]);
+        // if (input[idx] > 20.0) {
+        //     t *= t;
+        //     output[idx] = (t - 1.0) / (t + 1.0);
+        // } else {
+        //     float tt = t * t;
+        //     output[idx] = (tt + 2.0 * t) / (tt + 2.0 * t + 2.0);
+        // }
+        // output[idx] *= input[idx];
+        output[idx] = input[idx] * tanh_activate_kernel(softplus_kernel(input[idx]));  // mish(x) = x * tanh(softplus(x))
     }
 
     void MishPlugin::forwardGpu(const float *const *inputs, float *output, cudaStream_t stream, int batchSize) {
@@ -69,5 +84,5 @@ namespace nvinfer1
         forwardGpu((const float *const *)inputs, (float *)outputs[0], stream, batchSize);
         return 0;
     }
-
 }
+
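
For context, below is a minimal standalone sketch of how the new kernels might be exercised outside the plugin. The host-side harness, the test values, and the one-thread-per-element launch with 256 threads per block are assumptions for illustration; the diff does not show the actual body of MishPlugin::forwardGpu. The device functions are copied verbatim from the hunk above, and the output is compared against a double-precision reference mish(x) = x * tanh(log(1 + exp(x))).

#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

// Device functions as introduced in the diff above.
__device__ float tanh_activate_kernel(float x) { return (2 / (1 + expf(-2 * x)) - 1); }

__device__ float softplus_kernel(float x, float threshold = 20) {
    if (x > threshold) return x;              // too large: softplus(x) ~= x
    else if (x < -threshold) return expf(x);  // too small: softplus(x) ~= exp(x)
    return logf(expf(x) + 1);
}

__global__ void mish_kernel(const float *input, float *output, int num_elem) {
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx >= num_elem) return;
    output[idx] = input[idx] * tanh_activate_kernel(softplus_kernel(input[idx]));
}

int main() {
    const int n = 8;
    float h_in[n] = {-30.f, -5.f, -1.f, 0.f, 1.f, 5.f, 25.f, 30.f};  // spans both saturation branches
    float h_out[n];

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    // Assumed launch configuration: one thread per element, 256 threads per block.
    const int block = 256;
    mish_kernel<<<(n + block - 1) / block, block>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

    // Double-precision reference; the thresholded kernel should agree closely.
    for (int i = 0; i < n; ++i) {
        double ref = h_in[i] * tanh(log1p(exp((double)h_in[i])));
        printf("x = %6.1f  mish = %10.6f  ref = %10.6f\n", h_in[i], h_out[i], ref);
    }

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The threshold branches matter because expf overflows to infinity near x = 89 in single precision; clamping softplus to x for large inputs and to exp(x) for very negative inputs keeps the intermediate finite while changing the result only negligibly.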