Skip to content

Commit 9f9109f

Browse files
give retinaface TRT8 support (wang-xinyu#698)
* give retinaface TRT8 support * add macros.h to retinaface conversion * add macros.h to retinaface conversion
1 parent 7c4476f commit 9f9109f

File tree

6 files changed

+76
-60
lines changed

6 files changed

+76
-60
lines changed

retinaface/calibrator.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ Int8EntropyCalibrator2::~Int8EntropyCalibrator2()
2626
CHECK(cudaFree(device_input_));
2727
}
2828

29-
int Int8EntropyCalibrator2::getBatchSize() const
29+
int Int8EntropyCalibrator2::getBatchSize() const TRT_NOEXCEPT
3030
{
3131
return batchsize_;
3232
}
3333

34-
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings)
34+
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT
3535
{
3636
if (img_idx_ + batchsize_ > (int)img_files_.size()) {
3737
return false;
@@ -57,7 +57,7 @@ bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int
5757
return true;
5858
}
5959

60-
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
60+
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) TRT_NOEXCEPT
6161
{
6262
std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
6363
calib_cache_.clear();
@@ -71,7 +71,7 @@ const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
7171
return length ? calib_cache_.data() : nullptr;
7272
}
7373

74-
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length)
74+
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT
7575
{
7676
std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
7777
std::ofstream output(calib_table_name_, std::ios::binary);

retinaface/calibrator.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "NvInfer.h"
55
#include <string>
66
#include <vector>
7+
#include "macros.h"
78

89
//! \class Int8EntropyCalibrator2
910
//!
@@ -16,10 +17,10 @@ class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
1617
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true);
1718

1819
virtual ~Int8EntropyCalibrator2();
19-
int getBatchSize() const override;
20-
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
21-
const void* readCalibrationCache(size_t& length) override;
22-
void writeCalibrationCache(const void* cache, size_t length) override;
20+
int getBatchSize() const TRT_NOEXCEPT override;
21+
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
22+
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
23+
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;
2324

2425
private:
2526
int batchsize_;

retinaface/decode.cu

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ namespace nvinfer1
1616
{
1717
}
1818

19-
void DecodePlugin::serialize(void* buffer) const
19+
void DecodePlugin::serialize(void* buffer) const TRT_NOEXCEPT
2020
{
2121
}
2222

23-
size_t DecodePlugin::getSerializationSize() const
24-
{
23+
size_t DecodePlugin::getSerializationSize() const TRT_NOEXCEPT
24+
{
2525
return 0;
2626
}
2727

28-
int DecodePlugin::initialize()
28+
int DecodePlugin::initialize() TRT_NOEXCEPT
2929
{
3030
return 0;
3131
}
3232

33-
Dims DecodePlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
33+
Dims DecodePlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT
3434
{
3535
//output the result to channel
3636
int totalCount = 0;
@@ -42,63 +42,63 @@ namespace nvinfer1
4242
}
4343

4444
// Set plugin namespace
45-
void DecodePlugin::setPluginNamespace(const char* pluginNamespace)
45+
void DecodePlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT
4646
{
4747
mPluginNamespace = pluginNamespace;
4848
}
4949

50-
const char* DecodePlugin::getPluginNamespace() const
50+
const char* DecodePlugin::getPluginNamespace() const TRT_NOEXCEPT
5151
{
5252
return mPluginNamespace;
5353
}
5454

5555
// Return the DataType of the plugin output at the requested index
56-
DataType DecodePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
56+
DataType DecodePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT
5757
{
5858
return DataType::kFLOAT;
5959
}
6060

6161
// Return true if output tensor is broadcast across a batch.
62-
bool DecodePlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
62+
bool DecodePlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT
6363
{
6464
return false;
6565
}
6666

6767
// Return true if plugin can use input that is broadcast across batch without replication.
68-
bool DecodePlugin::canBroadcastInputAcrossBatch(int inputIndex) const
68+
bool DecodePlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT
6969
{
7070
return false;
7171
}
7272

73-
void DecodePlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
73+
void DecodePlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT
7474
{
7575
}
7676

7777
// Attach the plugin object to an execution context and grant the plugin the access to some context resource.
78-
void DecodePlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
78+
void DecodePlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT
7979
{
8080
}
8181

8282
// Detach the plugin object from its execution context.
83-
void DecodePlugin::detachFromContext() {}
83+
void DecodePlugin::detachFromContext() TRT_NOEXCEPT {}
8484

85-
const char* DecodePlugin::getPluginType() const
85+
const char* DecodePlugin::getPluginType() const TRT_NOEXCEPT
8686
{
8787
return "Decode_TRT";
8888
}
8989

90-
const char* DecodePlugin::getPluginVersion() const
90+
const char* DecodePlugin::getPluginVersion() const TRT_NOEXCEPT
9191
{
9292
return "1";
9393
}
9494

95-
void DecodePlugin::destroy()
95+
void DecodePlugin::destroy() TRT_NOEXCEPT
9696
{
9797
delete this;
9898
}
9999

100100
// Clone the plugin
101-
IPluginV2IOExt* DecodePlugin::clone() const
101+
IPluginV2IOExt* DecodePlugin::clone() const TRT_NOEXCEPT
102102
{
103103
DecodePlugin *p = new DecodePlugin();
104104
p->setPluginNamespace(mPluginNamespace);
@@ -190,7 +190,7 @@ namespace nvinfer1
190190
}
191191
}
192192

193-
int DecodePlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
193+
int DecodePlugin::enqueue(int batchSize, const void*const * inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT
194194
{
195195
//GPU
196196
//CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -209,29 +209,29 @@ namespace nvinfer1
209209
mFC.fields = mPluginAttributes.data();
210210
}
211211

212-
const char* DecodePluginCreator::getPluginName() const
212+
const char* DecodePluginCreator::getPluginName() const TRT_NOEXCEPT
213213
{
214214
return "Decode_TRT";
215215
}
216216

217-
const char* DecodePluginCreator::getPluginVersion() const
217+
const char* DecodePluginCreator::getPluginVersion() const TRT_NOEXCEPT
218218
{
219219
return "1";
220220
}
221221

222-
const PluginFieldCollection* DecodePluginCreator::getFieldNames()
222+
const PluginFieldCollection* DecodePluginCreator::getFieldNames() TRT_NOEXCEPT
223223
{
224224
return &mFC;
225225
}
226226

227-
IPluginV2IOExt* DecodePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
227+
IPluginV2IOExt* DecodePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT
228228
{
229229
DecodePlugin* obj = new DecodePlugin();
230230
obj->setPluginNamespace(mNamespace.c_str());
231231
return obj;
232232
}
233233

234-
IPluginV2IOExt* DecodePluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
234+
IPluginV2IOExt* DecodePluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
235235
{
236236
// This object will be deleted when the network is destroyed, which will
237237
// call PReluPlugin::destroy()

retinaface/decode.h

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <string>
55
#include <vector>
66
#include "NvInfer.h"
7+
#include "macros.h"
78

89
namespace decodeplugin
910
{
@@ -26,53 +27,53 @@ namespace nvinfer1
2627

2728
~DecodePlugin();
2829

29-
int getNbOutputs() const override
30+
int getNbOutputs() const TRT_NOEXCEPT override
3031
{
3132
return 1;
3233
}
3334

34-
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
35+
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
3536

36-
int initialize() override;
37+
int initialize() TRT_NOEXCEPT override;
3738

38-
virtual void terminate() override {};
39+
virtual void terminate() TRT_NOEXCEPT override {};
3940

40-
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
41+
virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0;}
4142

42-
virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
43+
virtual int enqueue(int batchSize, const void*const * inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
4344

44-
virtual size_t getSerializationSize() const override;
45+
virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
4546

46-
virtual void serialize(void* buffer) const override;
47+
virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
4748

48-
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
49+
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
4950
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
5051
}
5152

52-
const char* getPluginType() const override;
53+
const char* getPluginType() const TRT_NOEXCEPT override;
5354

54-
const char* getPluginVersion() const override;
55+
const char* getPluginVersion() const TRT_NOEXCEPT override;
5556

56-
void destroy() override;
57+
void destroy() TRT_NOEXCEPT override;
5758

58-
IPluginV2IOExt* clone() const override;
59+
IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
5960

60-
void setPluginNamespace(const char* pluginNamespace) override;
61+
void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
6162

62-
const char* getPluginNamespace() const override;
63+
const char* getPluginNamespace() const TRT_NOEXCEPT override;
6364

64-
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
65+
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
6566

66-
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
67+
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
6768

68-
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
69+
bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
6970

7071
void attachToContext(
71-
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
72+
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
7273

73-
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
74+
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
7475

75-
void detachFromContext() override;
76+
void detachFromContext() TRT_NOEXCEPT override;
7677

7778
int input_size_;
7879
private:
@@ -88,22 +89,22 @@ namespace nvinfer1
8889

8990
~DecodePluginCreator() override = default;
9091

91-
const char* getPluginName() const override;
92+
const char* getPluginName() const TRT_NOEXCEPT override;
9293

93-
const char* getPluginVersion() const override;
94+
const char* getPluginVersion() const TRT_NOEXCEPT override;
9495

95-
const PluginFieldCollection* getFieldNames() override;
96+
const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
9697

97-
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
98+
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
9899

99-
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
100+
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
100101

101-
void setPluginNamespace(const char* libNamespace) override
102+
void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
102103
{
103104
mNamespace = libNamespace;
104105
}
105106

106-
const char* getPluginNamespace() const override
107+
const char* getPluginNamespace() const TRT_NOEXCEPT override
107108
{
108109
return mNamespace.c_str();
109110
}

retinaface/logging.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <ostream>
2626
#include <sstream>
2727
#include <string>
28+
#include "macros.h"
29+
2830

2931
using Severity = nvinfer1::ILogger::Severity;
3032

@@ -236,7 +238,7 @@ class Logger : public nvinfer1::ILogger
236238
//! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
237239
//! inheritance from nvinfer1::ILogger
238240
//!
239-
void log(Severity severity, const char* msg) override
241+
void log(Severity severity, const char* msg) TRT_NOEXCEPT override
240242
{
241243
LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
242244
}

retinaface/macros.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef __MACROS_H
2+
#define __MACROS_H
3+
4+
#if NV_TENSORRT_MAJOR >= 8
5+
#define TRT_NOEXCEPT noexcept
6+
#define TRT_CONST_ENQUEUE const
7+
#else
8+
#define TRT_NOEXCEPT
9+
#define TRT_CONST_ENQUEUE
10+
#endif
11+
12+
#endif // __MACROS_H

0 commit comments

Comments
 (0)