Skip to content

Commit b7a754e

Browse files
authored
Fix for TensorRT 8 support (wang-xinyu#540)
1 parent 9bb2f67 commit b7a754e

File tree

5 files changed

+80
-59
lines changed

5 files changed

+80
-59
lines changed

yolov5/calibrator.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ Int8EntropyCalibrator2::~Int8EntropyCalibrator2()
2626
CUDA_CHECK(cudaFree(device_input_));
2727
}
2828

29-
int Int8EntropyCalibrator2::getBatchSize() const
29+
int Int8EntropyCalibrator2::getBatchSize() const TRT_NOEXCEPT
3030
{
3131
return batchsize_;
3232
}
3333

34-
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings)
34+
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT
3535
{
3636
if (img_idx_ + batchsize_ > (int)img_files_.size()) {
3737
return false;
@@ -57,7 +57,7 @@ bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int
5757
return true;
5858
}
5959

60-
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
60+
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) TRT_NOEXCEPT
6161
{
6262
std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
6363
calib_cache_.clear();
@@ -71,7 +71,7 @@ const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
7171
return length ? calib_cache_.data() : nullptr;
7272
}
7373

74-
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length)
74+
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT
7575
{
7676
std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
7777
std::ofstream output(calib_table_name_, std::ios::binary);

yolov5/calibrator.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
#include <string>
66
#include <vector>
77

8+
#if NV_TENSORRT_MAJOR >= 8
9+
#define TRT_NOEXCEPT noexcept
10+
#else
11+
#define TRT_NOEXCEPT
12+
#endif
13+
814
//! \class Int8EntropyCalibrator2
915
//!
1016
//! \brief Implements Entropy calibrator 2.
@@ -16,10 +22,10 @@ class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
1622
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true);
1723

1824
virtual ~Int8EntropyCalibrator2();
19-
int getBatchSize() const override;
20-
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
21-
const void* readCalibrationCache(size_t& length) override;
22-
void writeCalibrationCache(const void* cache, size_t length) override;
25+
int getBatchSize() const TRT_NOEXCEPT override;
26+
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
27+
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
28+
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;
2329

2430
private:
2531
int batchsize_;

yolov5/logging.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
#include <sstream>
2727
#include <string>
2828

29+
#if NV_TENSORRT_MAJOR >= 8
30+
#define TRT_NOEXCEPT noexcept
31+
#else
32+
#define TRT_NOEXCEPT
33+
#endif
34+
2935
using Severity = nvinfer1::ILogger::Severity;
3036

3137
class LogStreamConsumerBuffer : public std::stringbuf
@@ -236,7 +242,7 @@ class Logger : public nvinfer1::ILogger
236242
//! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
237243
//! inheritance from nvinfer1::ILogger
238244
//!
239-
void log(Severity severity, const char* msg) override
245+
void log(Severity severity, const char* msg) TRT_NOEXCEPT override
240246
{
241247
LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
242248
}

yolov5/yololayer.cu

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ namespace nvinfer1
7878
assert(d == a + length);
7979
}
8080

81-
void YoloLayerPlugin::serialize(void* buffer) const
81+
void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT
8282
{
8383
using namespace Tn;
8484
char* d = static_cast<char*>(buffer), *a = d;
@@ -95,17 +95,17 @@ namespace nvinfer1
9595
assert(d == a + getSerializationSize());
9696
}
9797

98-
size_t YoloLayerPlugin::getSerializationSize() const
98+
size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT
9999
{
100100
return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject);
101101
}
102102

103-
int YoloLayerPlugin::initialize()
103+
int YoloLayerPlugin::initialize() TRT_NOEXCEPT
104104
{
105105
return 0;
106106
}
107107

108-
Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
108+
Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT
109109
{
110110
//output the result to channel
111111
int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float);
@@ -114,63 +114,63 @@ namespace nvinfer1
114114
}
115115

116116
// Set plugin namespace
117-
void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace)
117+
void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT
118118
{
119119
mPluginNamespace = pluginNamespace;
120120
}
121121

122-
const char* YoloLayerPlugin::getPluginNamespace() const
122+
const char* YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT
123123
{
124124
return mPluginNamespace;
125125
}
126126

127127
// Return the DataType of the plugin output at the requested index
128-
DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
128+
DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT
129129
{
130130
return DataType::kFLOAT;
131131
}
132132

133133
// Return true if output tensor is broadcast across a batch.
134-
bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
134+
bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT
135135
{
136136
return false;
137137
}
138138

139139
// Return true if plugin can use input that is broadcast across batch without replication.
140-
bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
140+
bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT
141141
{
142142
return false;
143143
}
144144

145-
void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
145+
void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT
146146
{
147147
}
148148

149149
// Attach the plugin object to an execution context and grant the plugin the access to some context resource.
150-
void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
150+
void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT
151151
{
152152
}
153153

154154
// Detach the plugin object from its execution context.
155-
void YoloLayerPlugin::detachFromContext() {}
155+
void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {}
156156

157-
const char* YoloLayerPlugin::getPluginType() const
157+
const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT
158158
{
159159
return "YoloLayer_TRT";
160160
}
161161

162-
const char* YoloLayerPlugin::getPluginVersion() const
162+
const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT
163163
{
164164
return "1";
165165
}
166166

167-
void YoloLayerPlugin::destroy()
167+
void YoloLayerPlugin::destroy() TRT_NOEXCEPT
168168
{
169169
delete this;
170170
}
171171

172172
// Clone the plugin
173-
IPluginV2IOExt* YoloLayerPlugin::clone() const
173+
IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT
174174
{
175175
YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel);
176176
p->setPluginNamespace(mPluginNamespace);
@@ -252,7 +252,7 @@ namespace nvinfer1
252252
}
253253

254254

255-
int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream)
255+
int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT
256256
{
257257
forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize);
258258
return 0;
@@ -269,22 +269,22 @@ namespace nvinfer1
269269
mFC.fields = mPluginAttributes.data();
270270
}
271271

272-
const char* YoloPluginCreator::getPluginName() const
272+
const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT
273273
{
274274
return "YoloLayer_TRT";
275275
}
276276

277-
const char* YoloPluginCreator::getPluginVersion() const
277+
const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT
278278
{
279279
return "1";
280280
}
281281

282-
const PluginFieldCollection* YoloPluginCreator::getFieldNames()
282+
const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT
283283
{
284284
return &mFC;
285285
}
286286

287-
IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
287+
IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT
288288
{
289289
assert(fc->nbFields == 2);
290290
assert(strcmp(fc->fields[0].name, "netinfo") == 0);
@@ -301,7 +301,7 @@ namespace nvinfer1
301301
return obj;
302302
}
303303

304-
IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
304+
IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
305305
{
306306
// This object will be deleted when the network is destroyed, which will
307307
// call YoloLayerPlugin::destroy()

yolov5/yololayer.h

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
#include <string>
66
#include "NvInfer.h"
77

8+
9+
#if NV_TENSORRT_MAJOR >= 8
10+
#define TRT_NOEXCEPT noexcept
11+
#define TRT_CONST_ENQUEUE const
12+
#else
13+
#define TRT_NOEXCEPT
14+
#define TRT_CONST_ENQUEUE
15+
#endif
16+
817
namespace Yolo
918
{
1019
static constexpr int CHECK_COUNT = 3;
@@ -38,53 +47,53 @@ namespace nvinfer1
3847
YoloLayerPlugin(const void* data, size_t length);
3948
~YoloLayerPlugin();
4049

41-
int getNbOutputs() const override
50+
int getNbOutputs() const TRT_NOEXCEPT override
4251
{
4352
return 1;
4453
}
4554

46-
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
55+
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
4756

48-
int initialize() override;
57+
int initialize() TRT_NOEXCEPT override;
4958

50-
virtual void terminate() override {};
59+
virtual void terminate() TRT_NOEXCEPT override {};
5160

52-
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
61+
virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
5362

54-
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;
63+
virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
5564

56-
virtual size_t getSerializationSize() const override;
65+
virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
5766

58-
virtual void serialize(void* buffer) const override;
67+
virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
5968

60-
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
69+
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
6170
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
6271
}
6372

64-
const char* getPluginType() const override;
73+
const char* getPluginType() const TRT_NOEXCEPT override;
6574

66-
const char* getPluginVersion() const override;
75+
const char* getPluginVersion() const TRT_NOEXCEPT override;
6776

68-
void destroy() override;
77+
void destroy() TRT_NOEXCEPT override;
6978

70-
IPluginV2IOExt* clone() const override;
79+
IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
7180

72-
void setPluginNamespace(const char* pluginNamespace) override;
81+
void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
7382

74-
const char* getPluginNamespace() const override;
83+
const char* getPluginNamespace() const TRT_NOEXCEPT override;
7584

76-
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
85+
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
7786

78-
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
87+
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
7988

80-
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
89+
bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
8190

8291
void attachToContext(
83-
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
92+
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
8493

85-
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
94+
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
8695

87-
void detachFromContext() override;
96+
void detachFromContext() TRT_NOEXCEPT override;
8897

8998
private:
9099
void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
@@ -106,22 +115,22 @@ namespace nvinfer1
106115

107116
~YoloPluginCreator() override = default;
108117

109-
const char* getPluginName() const override;
118+
const char* getPluginName() const TRT_NOEXCEPT override;
110119

111-
const char* getPluginVersion() const override;
120+
const char* getPluginVersion() const TRT_NOEXCEPT override;
112121

113-
const PluginFieldCollection* getFieldNames() override;
122+
const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
114123

115-
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
124+
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
116125

117-
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
126+
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
118127

119-
void setPluginNamespace(const char* libNamespace) override
128+
void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
120129
{
121130
mNamespace = libNamespace;
122131
}
123132

124-
const char* getPluginNamespace() const override
133+
const char* getPluginNamespace() const TRT_NOEXCEPT override
125134
{
126135
return mNamespace.c_str();
127136
}

0 commit comments

Comments
 (0)