Skip to content

Commit 52c55be

Browse files
committed
prepare anchor during initialization
1 parent bd1c1ca commit 52c55be

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

yolov4/yololayer.cu

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ namespace nvinfer1
1313
mYoloKernel.push_back(yolo3);
1414

1515
mKernelCount = mYoloKernel.size();
16+
17+
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
18+
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
19+
for(int ii = 0; ii < mKernelCount; ii ++)
20+
{
21+
CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
22+
const auto& yolo = mYoloKernel[ii];
23+
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
24+
}
1625
}
1726

1827
YoloLayerPlugin::~YoloLayerPlugin()
@@ -32,6 +41,15 @@ namespace nvinfer1
3241
memcpy(mYoloKernel.data(),d,kernelSize);
3342
d += kernelSize;
3443

44+
CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
45+
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
46+
for(int ii = 0; ii < mKernelCount; ii ++)
47+
{
48+
CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
49+
const auto& yolo = mYoloKernel[ii];
50+
CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
51+
}
52+
3553
assert(d == a + length);
3654
}
3755

@@ -179,9 +197,6 @@ namespace nvinfer1
179197
}
180198

181199
void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
182-
void* devAnchor;
183-
size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
184-
CUDA_CHECK(cudaMalloc(&devAnchor,AnchorLen));
185200

186201
int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);
187202

@@ -195,12 +210,10 @@ namespace nvinfer1
195210
numElem = yolo.width*yolo.height*batchSize;
196211
if (numElem < mThreadCount)
197212
mThreadCount = numElem;
198-
CUDA_CHECK(cudaMemcpy(devAnchor, yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
199213
CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount>>>
200-
(inputs[i],output, numElem, yolo.width, yolo.height, (float *)devAnchor, mClassCount ,outputElem);
214+
(inputs[i],output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount ,outputElem);
201215
}
202216

203-
CUDA_CHECK(cudaFree(devAnchor));
204217
}
205218

206219

yolov4/yololayer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ namespace nvinfer1
116116
int mKernelCount;
117117
std::vector<Yolo::YoloKernel> mYoloKernel;
118118
int mThreadCount = 256;
119+
void** mAnchor;
119120
const char* mPluginNamespace;
120121
};
121122

0 commit comments

Comments
 (0)