
Commit 70bec40

Merge pull request wang-xinyu#53 from qiuyunzhe/anchor
Prepare anchor during initialization and add multi-gpu tutorial
2 parents bd1c1ca + 5d1dc21 commit 70bec40

3 files changed: +50 -6 lines changed


tutorials/multi_GPU_processing.md

Lines changed: 30 additions & 0 deletions
# How to Implement Multi-GPU Processing

Maybe you hope to take advantage of multiple GPUs to make inference even faster. Here are a few tips to help you do it, taking **YOLO V4** as an example.
## 1. Make custom plugins (i.e. the YOLO layer and Mish layer for YOLO V4) run asynchronously.

To do this, we need to pass the CUDA stream parameter into the kernels of all custom layers and use asynchronous functions.
For example, in the function `forwardGpu()` of **yololayer.cu**, you need to make the following changes so that the engine runs on a specific CUDA stream (a sketch of where `stream` comes from follows the list).

1) Change `cudaMemset(output + idx*outputElem, 0, sizeof(float))` to `cudaMemsetAsync(output + idx*outputElem, 0, sizeof(float), stream)`
2) Change `CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount>>>(inputs[i],output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount ,outputElem)` to `CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>(inputs[i],output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount ,outputElem)`
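
The `stream` used above is the one TensorRT hands to the plugin's `enqueue()` method. Here is a minimal sketch of that plumbing; the `enqueue()` signature is the standard `IPluginV2` one, and the `forwardGpu()` call mirrors this repo's helper:

```
// Sketch: forward the caller's CUDA stream through the plugin so every kernel
// launch and async copy runs on it, rather than on the default stream.
int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs,
                             void** outputs, void* workspace, cudaStream_t stream)
{
    forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize);
    return 0;
}
```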
## 2. Create an engine for each device you want to use.

It may be a good idea to create a struct that stores the engine, context, and buffers of each device individually. For example,
```
struct Plan{
    IRuntime* runtime;           // runtime used to deserialize this device's engine
    ICudaEngine* engine;         // engine optimized for this device
    IExecutionContext* context;  // execution context bound to the engine
    void* buffers[2];            // device-side input/output bindings
    cudaStream_t stream;         // stream this device's inference runs on
};
```
And then use `cudaSetDevice()` to make each engine you create run on a specific device. Moreover, to maximize performance, make sure that the engine file you deserialize on a device is the one TensorRT optimized for that device. A per-device setup sketch follows.
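
For illustration, here is a minimal setup sketch built on the `Plan` struct above. The helper name `initPlan`, the `engineData`/`engineSize` buffer (the serialized engine built for this device), the byte sizes, and the `gLogger` instance are assumptions for the example, not code from this commit:

```
// Sketch: build one Plan per GPU. Every CUDA/TensorRT call issued after
// cudaSetDevice() targets that GPU.
void initPlan(Plan& plan, int deviceId, const void* engineData, size_t engineSize,
              size_t inputBytes, size_t outputBytes)
{
    cudaSetDevice(deviceId);
    plan.runtime = createInferRuntime(gLogger);
    plan.engine  = plan.runtime->deserializeCudaEngine(engineData, engineSize);
    plan.context = plan.engine->createExecutionContext();
    cudaMalloc(&plan.buffers[0], inputBytes);   // input binding
    cudaMalloc(&plan.buffers[1], outputBytes);  // output binding
    cudaStreamCreate(&plan.stream);             // per-device stream
}
```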
## 3. Use functions wisely

Here are some things I learned when trying to parallelize the inference (see the sketch after this list).
1) Do not use synchronizing functions, like `cudaFree()`, during inference; they stall the streams on the device.
2) Use `cudaMallocHost()` instead of `malloc()` when allocating memory on the host side, so that host-device copies can run asynchronously.
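
To make tip 2 concrete, here is a minimal sketch of the pinned-memory pattern, reusing `plan` and `inputBytes` from the previous sketch. `cudaMemcpyAsync()` only overlaps with other work when the host buffer is page-locked, which is exactly what `cudaMallocHost()` provides:

```
// Sketch: pinned host memory makes host<->device copies truly asynchronous.
float* hostInput;
cudaMallocHost(&hostInput, inputBytes);   // page-locked allocation

// ... fill hostInput with preprocessed input data ...

cudaMemcpyAsync(plan.buffers[0], hostInput, inputBytes,
                cudaMemcpyHostToDevice, plan.stream);   // returns immediately
// enqueue inference and the device->host copy on plan.stream, then:
cudaStreamSynchronize(plan.stream);       // wait once, outside the hot path

cudaFreeHost(hostInput);                  // free after inference, not during it
```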

yolov4/yololayer.cu

Lines changed: 19 additions & 6 deletions

```
@@ -13,6 +13,15 @@ namespace nvinfer1
         mYoloKernel.push_back(yolo3);

         mKernelCount = mYoloKernel.size();
+
+        CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+        size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
+        for(int ii = 0; ii < mKernelCount; ii ++)
+        {
+            CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+            const auto& yolo = mYoloKernel[ii];
+            CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+        }
     }

     YoloLayerPlugin::~YoloLayerPlugin()
@@ -32,6 +41,15 @@ namespace nvinfer1
         memcpy(mYoloKernel.data(),d,kernelSize);
         d += kernelSize;

+        CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+        size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
+        for(int ii = 0; ii < mKernelCount; ii ++)
+        {
+            CUDA_CHECK(cudaMalloc(&mAnchor[ii],AnchorLen));
+            const auto& yolo = mYoloKernel[ii];
+            CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+        }
+
         assert(d == a + length);
     }

@@ -179,9 +197,6 @@ namespace nvinfer1
     }

     void YoloLayerPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
-        void* devAnchor;
-        size_t AnchorLen = sizeof(float)* CHECK_COUNT*2;
-        CUDA_CHECK(cudaMalloc(&devAnchor,AnchorLen));

         int outputElem = 1 + MAX_OUTPUT_BBOX_COUNT * sizeof(Detection) / sizeof(float);

@@ -195,12 +210,10 @@ namespace nvinfer1
             numElem = yolo.width*yolo.height*batchSize;
             if (numElem < mThreadCount)
                 mThreadCount = numElem;
-            CUDA_CHECK(cudaMemcpy(devAnchor, yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
             CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount>>>
-                (inputs[i],output, numElem, yolo.width, yolo.height, (float *)devAnchor, mClassCount ,outputElem);
+                (inputs[i],output, numElem, yolo.width, yolo.height, (float *)mAnchor[i], mClassCount ,outputElem);
         }

-        CUDA_CHECK(cudaFree(devAnchor));
     }
```
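
The net effect of this diff: the per-scale anchor tables are copied to the GPU once, at construction/deserialization time, instead of being allocated, copied, and freed on every forward pass (those synchronous `cudaMalloc`/`cudaMemcpy`/`cudaFree` calls also block streams, per tip 1 of the tutorial). Below is a standalone sketch of the same pattern; the names mirror the diff, but the helper functions and the teardown are illustrative assumptions, not code this commit adds:

```
// Sketch: pre-allocate per-scale anchor tables on the device once and reuse
// them in every forward pass. Teardown shown for completeness; the commit
// itself does not add it.
void** mAnchor;  // pinned host array of device pointers, one per YOLO scale

void setupAnchors(const std::vector<Yolo::YoloKernel>& kernels) {
    size_t anchorLen = sizeof(float) * CHECK_COUNT * 2;
    CUDA_CHECK(cudaMallocHost(&mAnchor, kernels.size() * sizeof(void*)));
    for (size_t i = 0; i < kernels.size(); ++i) {
        CUDA_CHECK(cudaMalloc(&mAnchor[i], anchorLen));             // device table
        CUDA_CHECK(cudaMemcpy(mAnchor[i], kernels[i].anchors,
                              anchorLen, cudaMemcpyHostToDevice));  // copy once
    }
}

void teardownAnchors(size_t kernelCount) {
    for (size_t i = 0; i < kernelCount; ++i)
        CUDA_CHECK(cudaFree(mAnchor[i]));
    CUDA_CHECK(cudaFreeHost(mAnchor));
}
```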

yolov4/yololayer.h

Lines changed: 1 addition & 0 deletions

```
@@ -116,6 +116,7 @@ namespace nvinfer1
         int mKernelCount;
         std::vector<Yolo::YoloKernel> mYoloKernel;
         int mThreadCount = 256;
+        void** mAnchor;
         const char* mPluginNamespace;
     };
```