Commit fbad226

yolov5: upgrade to v7.0 and support instance segmentation (wang-xinyu#1177)
* add seg
* yolov5 seg
* update wts and readme
* add todo
* update readme
1 parent 165b0a4 commit fbad226

8 files changed, with 501 additions and 21 deletions.

README.md

Lines changed: 2 additions & 2 deletions
@@ -15,6 +15,7 @@ The basic workflow of TensorRTx is:
 ## News

+- `18 Dec 2022`. [YOLOv5](./yolov5) upgrade to support v7.0, including instance segmentation.
 - `12 Dec 2022`. [East-Face](https://github.com/East-Face): [UNet](./unet) upgrade to support v3.0 of [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet).
 - `26 Oct 2022`. [ausk](https://github.com/ausk): YoloP(You Only Look Once for Panoptic Driving Perception).
 - `19 Sep 2022`. [QIANXUNZDL123](https://github.com/QIANXUNZDL123) and [lindsayshuo](https://github.com/lindsayshuo): YOLOv7.
@@ -29,7 +30,6 @@ The basic workflow of TensorRTx is:
 - `18 Oct 2021`. [xupengao](https://github.com/xupengao): YOLOv5 updated to v6.0, supporting n/s/m/l/x/n6/s6/m6/l6/x6.
 - `31 Aug 2021`. [FamousDirector](https://github.com/FamousDirector): update retinaface to support TensorRT 8.0.
 - `27 Aug 2021`. [HaiyangPeng](https://github.com/HaiyangPeng): add a python wrapper for hrnet segmentation.
-- `1 Jul 2021`. [freedenS](https://github.com/freedenS): DE⫶TR: End-to-End Object Detection with Transformers. First Transformer model!

 ## Tutorials

@@ -75,7 +75,7 @@ Following models are implemented.
 |[yolov3](./yolov3)| darknet-53, weights and pytorch implementation from [ultralytics/yolov3](https://github.com/ultralytics/yolov3) |
 |[yolov3-spp](./yolov3-spp)| darknet-53, weights and pytorch implementation from [ultralytics/yolov3](https://github.com/ultralytics/yolov3) |
 |[yolov4](./yolov4)| CSPDarknet53, weights from [AlexeyAB/darknet](https://github.com/AlexeyAB/darknet#pre-trained-models), pytorch implementation from [ultralytics/yolov3](https://github.com/ultralytics/yolov3) |
-|[yolov5](./yolov5)| yolov5 v1.0-v6.2, pytorch implementation from [ultralytics/yolov5](https://github.com/ultralytics/yolov5) |
+|[yolov5](./yolov5)| yolov5 v1.0-v7.0 of [ultralytics/yolov5](https://github.com/ultralytics/yolov5), detection, classification and instance segmentation |
 |[yolov7](./yolov7)| yolov7 v0.1, pytorch implementation from [WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7) |
 |[yolop](./yolop)| yolop, pytorch implementation from [hustvl/YOLOP](https://github.com/hustvl/YOLOP) |
 |[retinaface](./retinaface)| resnet50 and mobilenet0.25, weights from [biubug6/Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) |

yolov5/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
@@ -20,6 +20,7 @@ include_directories(${PROJECT_SOURCE_DIR}/include)
 include_directories(/usr/local/cuda/include)
 link_directories(/usr/local/cuda/lib64)
 # tensorrt
+# TODO(Call for PR): make TRT path configurable from command line
 include_directories(/usr/include/x86_64-linux-gnu/)
 link_directories(/usr/lib/x86_64-linux-gnu/)

@@ -44,8 +45,14 @@ target_link_libraries(yolov5-cls cudart)
 target_link_libraries(yolov5-cls myplugins)
 target_link_libraries(yolov5-cls ${OpenCV_LIBS})

+cuda_add_executable(yolov5-seg calibrator.cpp yolov5_seg.cpp preprocess.cu)
+
+target_link_libraries(yolov5-seg nvinfer)
+target_link_libraries(yolov5-seg cudart)
+target_link_libraries(yolov5-seg myplugins)
+target_link_libraries(yolov5-seg ${OpenCV_LIBS})
+
 if(UNIX)
 add_definitions(-O2 -pthread)
 endif(UNIX)

-

yolov5/README.md

Lines changed: 21 additions & 5 deletions
@@ -33,8 +33,9 @@ TensorRTx inference code base for [ultralytics/yolov5](https://github.com/ultral
 ## Different versions of yolov5

-Currently, we support yolov5 v1.0, v2.0, v3.0, v3.1, v4.0, v5.0, v6.0, v6.2
+Currently, we support yolov5 v1.0, v2.0, v3.0, v3.1, v4.0, v5.0, v6.0, v6.2, v7.0

+- For yolov5 v7.0, download .pt from [yolov5 release v7.0](https://github.com/ultralytics/yolov5/releases/tag/v7.0), `git clone -b v7.0 https://github.com/ultralytics/yolov5.git` and `git clone -b yolov5-v7.0 https://github.com/wang-xinyu/tensorrtx.git`, then follow how-to-run in [tensorrtx/yolov5-v7.0](https://github.com/wang-xinyu/tensorrtx/tree/yolov5-v7.0/yolov5)
 - For yolov5 v6.2, download .pt from [yolov5 release v6.2](https://github.com/ultralytics/yolov5/releases/tag/v6.2), `git clone -b v6.2 https://github.com/ultralytics/yolov5.git` and `git clone -b yolov5-v6.2 https://github.com/wang-xinyu/tensorrtx.git`, then follow how-to-run in [tensorrtx/yolov5-v6.2](https://github.com/wang-xinyu/tensorrtx/tree/yolov5-v6.2/yolov5)
 - For yolov5 v6.0, download .pt from [yolov5 release v6.0](https://github.com/ultralytics/yolov5/releases/tag/v6.0), `git clone -b v6.0 https://github.com/ultralytics/yolov5.git` and `git clone -b yolov5-v6.0 https://github.com/wang-xinyu/tensorrtx.git`, then follow how-to-run in [tensorrtx/yolov5-v6.0](https://github.com/wang-xinyu/tensorrtx/tree/yolov5-v6.0/yolov5).
 - For yolov5 v5.0, download .pt from [yolov5 release v5.0](https://github.com/ultralytics/yolov5/releases/tag/v5.0), `git clone -b v5.0 https://github.com/ultralytics/yolov5.git` and `git clone -b yolov5-v5.0 https://github.com/wang-xinyu/tensorrtx.git`, then follow how-to-run in [tensorrtx/yolov5-v5.0](https://github.com/wang-xinyu/tensorrtx/tree/yolov5-v5.0/yolov5).
@@ -63,7 +64,7 @@ Currently, we support yolov5 v1.0, v2.0, v3.0, v3.1, v4.0, v5.0, v6.0, v6.2

 ```
 // clone code according to above #Different versions of yolov5
-// download https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s.pt
+// download https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.pt
 cp {tensorrtx}/yolov5/gen_wts.py {ultralytics}/yolov5
 cd {ultralytics}/yolov5
 python gen_wts.py -w yolov5s.pt -o yolov5s.wts
@@ -103,6 +104,10 @@ python yolov5_trt.py
 python yolov5_trt_cuda_python.py
 ```

+<p align="center">
+<img src="https://user-images.githubusercontent.com/15235574/78247927-4d9fac00-751e-11ea-8b1b-704a0aeb3fcf.jpg" height="360px;">
+</p>
+
 ### Classification

 ```
@@ -116,6 +121,20 @@ wget https://github.com/joannzhang00/ImageNet-dataset-classes-labels/blob/main/i
 ./yolov5-cls -d yolov5s-cls.engine ../samples
 ```

+### Instance Segmentation
+
+```
+# Build and serialize TensorRT engine
+./yolov5-seg -s yolov5s-seg.wts yolov5s-seg.engine s
+
+# Run inference
+./yolov5-seg -d yolov5s-seg.engine ../samples
+```
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/15235574/208305921-0a2ee358-6550-4d36-bb86-867685bfe069.jpg" height="360px;">
+</p>
+
 # INT8 Quantization

 1. Prepare calibration images, you can randomly select 1000s images from your train set. For coco, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh
@@ -126,9 +145,6 @@ wget https://github.com/joannzhang00/ImageNet-dataset-classes-labels/blob/main/i

 4. serialize the model and test

-<p align="center">
-<img src="https://user-images.githubusercontent.com/15235574/78247927-4d9fac00-751e-11ea-8b1b-704a0aeb3fcf.jpg" height="360px;">
-</p>

 ## More Information

yolov5/common.hpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ ILayer* convBlock(INetworkDefinition *network, std::map<std::string, Weights>& w
     conv1->setStrideNd(DimsHW{ s, s });
     conv1->setPaddingNd(DimsHW{ p, p });
     conv1->setNbGroups(g);
+    conv1->setName((lname + ".conv").c_str());
     IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn", 1e-3);

     // silu = x * sigmoid
@@ -273,6 +274,21 @@ ILayer* SPPF(INetworkDefinition *network, std::map<std::string, Weights>& weight
     return cv2;
 }

+ILayer* Proto(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, int c_, int c2, std::string lname) {
+    auto cv1 = convBlock(network, weightMap, input, c_, 3, 1, 1, lname + ".cv1");
+
+    auto upsample = network->addResize(*cv1->getOutput(0));
+    assert(upsample);
+    upsample->setResizeMode(ResizeMode::kNEAREST);
+    const float scales[] = {1, 2, 2};
+    upsample->setScales(scales, 3);
+
+    auto cv2 = convBlock(network, weightMap, *upsample->getOutput(0), c_, 3, 1, 1, lname + ".cv2");
+    auto cv3 = convBlock(network, weightMap, *cv2->getOutput(0), c2, 1, 1, 1, lname + ".cv3");
+    assert(cv3);
+    return cv3;
+}
+
 std::vector<std::vector<float>> getAnchors(std::map<std::string, Weights>& weightMap, std::string lname) {
     std::vector<std::vector<float>> anchors;
     Weights wts = weightMap[lname + ".anchor_grid"];
@@ -285,13 +301,13 @@ std::vector<std::vector<float>> getAnchors(std::map<std::string, Weights>& weigh
     return anchors;
 }

-IPluginV2Layer* addYoLoLayer(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, std::string lname, std::vector<IConvolutionLayer*> dets) {
+IPluginV2Layer* addYoLoLayer(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, std::string lname, std::vector<IConvolutionLayer*> dets, bool is_segmentation = false) {
     auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
     auto anchors = getAnchors(weightMap, lname);
     PluginField plugin_fields[2];
-    int netinfo[4] = {Yolo::CLASS_NUM, Yolo::INPUT_W, Yolo::INPUT_H, Yolo::MAX_OUTPUT_BBOX_COUNT};
+    int netinfo[5] = {Yolo::CLASS_NUM, Yolo::INPUT_W, Yolo::INPUT_H, Yolo::MAX_OUTPUT_BBOX_COUNT, (int)is_segmentation};
     plugin_fields[0].data = netinfo;
-    plugin_fields[0].length = 4;
+    plugin_fields[0].length = 5;
     plugin_fields[0].name = "netinfo";
     plugin_fields[0].type = PluginFieldType::kFLOAT32;
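
Note on the `Proto` block added in this file: the resize uses scales `{1, 2, 2}`, so the channel count is unchanged and height and width double, while the stride-1 convolutions around it keep the spatial size. A minimal sketch of that shape arithmetic follows. `Chw` and `proto_output_shape` are hypothetical names, and the example numbers in the usage note (an 80x80 input feature map and 32 output channels) are assumptions for illustration, not values taken from this diff.

```cpp
#include <cassert>

// Hypothetical helper type: a channels/height/width triple.
struct Chw {
    int c;
    int h;
    int w;
};

// Shape after a Proto-style block: the stride-1 convolutions keep H and W,
// the nearest-neighbour resize with scales {1, 2, 2} doubles H and W,
// and the last convolution sets the channel count.
inline Chw proto_output_shape(const Chw& in, int out_channels) {
    assert(in.h > 0 && in.w > 0 && out_channels > 0);
    const int scales[3] = {1, 2, 2};  // same scales as in the diff
    return Chw{out_channels, in.h * scales[1], in.w * scales[2]};
}
```

For example, `proto_output_shape(Chw{128, 80, 80}, 32)` yields a 32x160x160 prototype tensor under these assumptions.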

yolov5/gen_wts.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ def parse_args():
     parser.add_argument(
         '-o', '--output', help='Output (.wts) file path (optional)')
     parser.add_argument(
-        '-t', '--type', type=str, default='detect', choices=['detect', 'cls'],
+        '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg'],
         help='determines the model is detection/classification')
     args = parser.parse_args()
     if not os.path.isfile(args.weights):
@@ -37,7 +37,7 @@ def parse_args():
 model = torch.load(pt_file, map_location=device)  # load to FP32
 model = model['ema' if model.get('ema') else 'model'].float()

-if m_type == "detect":
+if m_type in ['detect', 'seg']:
     # update anchor_grid info
     anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
     # model.model[-1].anchor_grid = anchor_grid
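
Note on the `anchor_grid` line: it multiplies each detection layer's anchors by that layer's stride, which converts anchors stored in grid units back to pixel units, and the segmentation head now takes the same path as detection. A minimal C++ sketch of the same arithmetic, assuming anchors are flattened per layer as `{w0, h0, w1, h1, ...}`; the function name `anchor_grid` and this layout are illustrative assumptions, not the exporter's actual data structures.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// anchors[l] holds (w, h) pairs for detection layer l in grid units,
// flattened as {w0, h0, w1, h1, ...}; strides[l] is that layer's stride
// in pixels. Returns the anchors scaled to pixel units.
inline std::vector<std::vector<float>> anchor_grid(
        const std::vector<std::vector<float>>& anchors,
        const std::vector<float>& strides) {
    assert(anchors.size() == strides.size());
    std::vector<std::vector<float>> grid(anchors.size());
    for (std::size_t l = 0; l < anchors.size(); ++l) {
        for (float a : anchors[l]) {
            grid[l].push_back(a * strides[l]);  // grid units -> pixels
        }
    }
    return grid;
}
```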

yolov5/yololayer.cu

Lines changed: 16 additions & 7 deletions
@@ -25,12 +25,13 @@ using namespace Yolo;
 namespace nvinfer1
 {
-    YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel)
+    YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<Yolo::YoloKernel>& vYoloKernel)
     {
         mClassCount = classCount;
         mYoloV5NetWidth = netWidth;
         mYoloV5NetHeight = netHeight;
         mMaxOutObject = maxOut;
+        is_segmentation_ = is_segmentation;
         mYoloKernel = vYoloKernel;
         mKernelCount = vYoloKernel.size();

@@ -63,6 +64,7 @@ namespace nvinfer1
         read(d, mYoloV5NetWidth);
         read(d, mYoloV5NetHeight);
         read(d, mMaxOutObject);
+        read(d, is_segmentation_);
         mYoloKernel.resize(mKernelCount);
         auto kernelSize = mKernelCount * sizeof(YoloKernel);
         memcpy(mYoloKernel.data(), d, kernelSize);
@@ -88,6 +90,7 @@ namespace nvinfer1
         write(d, mYoloV5NetWidth);
         write(d, mYoloV5NetHeight);
         write(d, mMaxOutObject);
+        write(d, is_segmentation_);
         auto kernelSize = mKernelCount * sizeof(YoloKernel);
         memcpy(d, mYoloKernel.data(), kernelSize);
         d += kernelSize;
@@ -97,7 +100,7 @@ namespace nvinfer1

     size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT
     {
-        return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject);
+        return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject) + sizeof(is_segmentation_);
     }

     int YoloLayerPlugin::initialize() TRT_NOEXCEPT
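
Note on the three serialization hunks above: `serialize()`, the deserializing constructor, and `getSerializationSize()` have to agree on field order and total byte count, which is why adding `is_segmentation_` touches all three. The sketch below illustrates that invariant on a reduced state. `write_pod`, `read_pod`, and `PluginState` are hypothetical stand-ins modeled on the `write`/`read` calls in the diff; the plugin's real helpers are not shown in this excerpt.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-ins for the plugin's write/read helpers:
// copy one trivially copyable value and advance the cursor.
template <typename T>
void write_pod(char*& cursor, const T& value) {
    std::memcpy(cursor, &value, sizeof(T));
    cursor += sizeof(T);
}

template <typename T>
void read_pod(const char*& cursor, T& value) {
    std::memcpy(&value, cursor, sizeof(T));
    cursor += sizeof(T);
}

// Reduced, hypothetical subset of the plugin state.
struct PluginState {
    int class_count = 0;
    int max_out = 0;
    bool is_segmentation = false;
};

inline std::size_t serialization_size(const PluginState& s) {
    return sizeof(s.class_count) + sizeof(s.max_out) + sizeof(s.is_segmentation);
}

inline std::vector<char> serialize(const PluginState& s) {
    std::vector<char> buffer(serialization_size(s));
    char* cursor = buffer.data();
    write_pod(cursor, s.class_count);
    write_pod(cursor, s.max_out);
    write_pod(cursor, s.is_segmentation);
    // The cursor must land exactly at the end of the buffer.
    assert(cursor == buffer.data() + buffer.size());
    return buffer;
}

inline PluginState deserialize(const std::vector<char>& buffer) {
    PluginState s;
    const char* cursor = buffer.data();
    read_pod(cursor, s.class_count);  // same order as serialize()
    read_pod(cursor, s.max_out);
    read_pod(cursor, s.is_segmentation);
    assert(cursor == buffer.data() + buffer.size());
    return s;
}
```

One consequence of the new field: a buffer written without the trailing flag no longer matches what the reader expects, so engines serialized before this change should be rebuilt.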
@@ -172,15 +175,15 @@ namespace nvinfer1
     // Clone the plugin
     IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT
     {
-        YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel);
+        YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, is_segmentation_, mYoloKernel);
         p->setPluginNamespace(mPluginNamespace);
         return p;
     }

     __device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); };

     __global__ void CalDetection(const float *input, float *output, int noElements,
-        const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem)
+        const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem, bool is_segmentation)
     {

         int idx = threadIdx.x + blockDim.x * blockIdx.x;
@@ -190,14 +193,15 @@ namespace nvinfer1
         int bnIdx = idx / total_grid;
         idx = idx - total_grid * bnIdx;
         int info_len_i = 5 + classes;
+        if (is_segmentation) info_len_i += 32;
         const float* curInput = input + bnIdx * (info_len_i * total_grid * CHECK_COUNT);

         for (int k = 0; k < CHECK_COUNT; ++k) {
             float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
             if (box_prob < IGNORE_THRESH) continue;
             int class_id = 0;
             float max_cls_prob = 0.0;
-            for (int i = 5; i < info_len_i; ++i) {
+            for (int i = 5; i < 5 + classes; ++i) {
                 float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
                 if (p > max_cls_prob) {
                     max_cls_prob = p;
@@ -230,6 +234,10 @@ namespace nvinfer1
             det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1];
             det->conf = box_prob * max_cls_prob;
             det->class_id = class_id;
+
+            for (int i = 0; is_segmentation && i < 32; i++) {
+                det->mask[i] = curInput[idx + k * info_len_i * total_grid + (i + 5 + classes) * total_grid];
+            }
         }
     }
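
Note on the indexing in `CalDetection`: for each anchor `k`, the head output is read channel by channel, with one plane of `total_grid` values per channel, and the per-anchor channel count grows from `5 + classes` to `5 + classes + 32` when segmentation is enabled. The host-side sketch below mirrors that offset arithmetic so it can be checked without a GPU. The function names are hypothetical; only the formulas are taken from the kernel.

```cpp
#include <cassert>

constexpr int kMaskCoeffs = 32;  // number of mask coefficients read by the kernel

// Channels per anchor: 4 box values + objectness + class scores,
// plus the mask coefficients when segmentation is enabled.
inline int info_len(int classes, bool is_segmentation) {
    return 5 + classes + (is_segmentation ? kMaskCoeffs : 0);
}

// Flat offset of channel `c` for anchor `k` at grid cell `idx`, mirroring
// curInput[idx + k * info_len_i * total_grid + c * total_grid] in the kernel.
inline int channel_offset(int idx, int k, int c, int total_grid, int info_len_i) {
    assert(idx >= 0 && idx < total_grid);
    assert(c >= 0 && c < info_len_i);
    return idx + k * info_len_i * total_grid + c * total_grid;
}

// Channel index of the i-th mask coefficient: it follows the class scores.
inline int mask_channel(int i, int classes) {
    assert(i >= 0 && i < kMaskCoeffs);
    return 5 + classes + i;
}
```

This also shows why the class loop bound changed from `info_len_i` to `5 + classes`: with segmentation on, the old bound would have scored the mask coefficients as if they were classes.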

@@ -247,7 +255,7 @@ namespace nvinfer1

             //printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight);
             CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >
-                (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem);
+                (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem, is_segmentation_);
         }
     }

@@ -294,9 +302,10 @@ namespace nvinfer1
         int input_w = p_netinfo[1];
         int input_h = p_netinfo[2];
         int max_output_object_count = p_netinfo[3];
+        bool is_segmentation = (bool)p_netinfo[4];
         std::vector<Yolo::YoloKernel> kernels(fc->fields[1].length);
         memcpy(&kernels[0], fc->fields[1].data, kernels.size() * sizeof(Yolo::YoloKernel));
-        YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, kernels);
+        YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, is_segmentation, kernels);
         obj->setPluginNamespace(mNamespace.c_str());
         return obj;
     }

yolov5/yololayer.h

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@ namespace Yolo
         float bbox[LOCATIONS];
         float conf;  // bbox_conf * cls_conf
         float class_id;
+        float mask[32];
     };
 }

@@ -35,7 +36,7 @@ namespace nvinfer1
     class API YoloLayerPlugin : public IPluginV2IOExt
     {
     public:
-        YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
+        YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<Yolo::YoloKernel>& vYoloKernel);
         YoloLayerPlugin(const void* data, size_t length);
         ~YoloLayerPlugin();

@@ -96,6 +97,7 @@ namespace nvinfer1
         int mYoloV5NetWidth;
         int mYoloV5NetHeight;
         int mMaxOutObject;
+        bool is_segmentation_;
         std::vector<Yolo::YoloKernel> mYoloKernel;
         void** mAnchor;
     };
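
Note on the new `mask[32]` field in `Yolo::Detection`: each detection now carries 32 mask coefficients. In upstream YOLOv5 v7.0 segmentation, an instance mask is the sigmoid of the linear combination of the prototype masks weighted by those coefficients. The sketch below shows that step on the CPU. How `yolov5_seg.cpp` implements it is not visible in this excerpt, so `assemble_mask` and the `[32][height * width]` prototype layout are assumptions for illustration only.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr int kMaskDim = 32;  // matches `float mask[32]` in Yolo::Detection

// Hypothetical helper: combine one detection's coefficients with the
// prototype masks. `proto` is laid out as [kMaskDim][pixels], flattened
// row-major. Returns one probability in (0, 1) per prototype pixel.
inline std::vector<float> assemble_mask(const float (&coeffs)[kMaskDim],
                                        const std::vector<float>& proto,
                                        std::size_t pixels) {
    assert(proto.size() == static_cast<std::size_t>(kMaskDim) * pixels);
    std::vector<float> mask(pixels, 0.0f);
    for (int c = 0; c < kMaskDim; ++c) {
        for (std::size_t p = 0; p < pixels; ++p) {
            mask[p] += coeffs[c] * proto[c * pixels + p];
        }
    }
    for (float& v : mask) {
        v = 1.0f / (1.0f + std::exp(-v));  // sigmoid
    }
    return mask;
}
```

The resulting probabilities are typically thresholded at 0.5 and cropped to the detection's bounding box before being resized to the input image.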
