Skip to content

Commit 9001986

Browse files
committed
optimize mish in yolov4
1 parent 01690fc commit 9001986

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

yolov4/mish.cu

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,29 @@ namespace nvinfer1
4444
return DimsCHW(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
4545
}
4646

47-
__device__ float softplus(float x) { return (x > 20.0) ? x : log(1.0 + exp(x)); }
47+
__device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}
48+
49+
__device__ float softplus_kernel(float x, float threshold = 20) {
50+
if (x > threshold) return x; // too large
51+
else if (x < -threshold) return expf(x); // too small
52+
return logf(expf(x) + 1);
53+
}
4854

4955
__global__ void mish_kernel(const float *input, float *output, int num_elem) {
5056

5157
int idx = threadIdx.x + blockDim.x * blockIdx.x;
5258
if (idx >= num_elem) return;
5359

54-
output[idx] = input[idx] * tanh(softplus(input[idx]));
60+
//float t = exp(input[idx]);
61+
//if (input[idx] > 20.0) {
62+
// t *= t;
63+
// output[idx] = (t - 1.0) / (t + 1.0);
64+
//} else {
65+
// float tt = t * t;
66+
// output[idx] = (tt + 2.0 * t) / (tt + 2.0 * t + 2.0);
67+
//}
68+
//output[idx] *= input[idx];
69+
output[idx] = input[idx] * tanh_activate_kernel(softplus_kernel(input[idx]));
5570
}
5671

5772
void MishPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
@@ -69,5 +84,5 @@ namespace nvinfer1
6984
forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
7085
return 0;
7186
}
72-
7387
}
88+

yolov4/yolov4.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <opencv2/opencv.hpp>
1515
#include <dirent.h>
1616

17-
//#define USE_FP16 // comment out this if want to use FP32
17+
#define USE_FP16 // comment out this if want to use FP32
1818
#define DEVICE 0 // GPU id
1919
#define NMS_THRESH 0.4
2020
#define BBOX_CONF_THRESH 0.5
@@ -649,13 +649,13 @@ int main(int argc, char** argv) {
649649
// Run inference
650650
auto start = std::chrono::system_clock::now();
651651
doInference(*context, data, prob, BATCH_SIZE);
652+
auto end = std::chrono::system_clock::now();
653+
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
652654
std::vector<std::vector<Yolo::Detection>> batch_res(fcount);
653655
for (int b = 0; b < fcount; b++) {
654656
auto& res = batch_res[b];
655657
nms(res, &prob[b * OUTPUT_SIZE]);
656658
}
657-
auto end = std::chrono::system_clock::now();
658-
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
659659
for (int b = 0; b < fcount; b++) {
660660
auto& res = batch_res[b];
661661
//std::cout << res.size() << std::endl;

0 commit comments

Comments
 (0)