Skip to content

Commit 4d01e09

Browse files
committed
Revert "Revert "hrnet classification trt 加速""
This reverts commit 8323f0b.
1 parent 8323f0b commit 4d01e09

File tree

2 files changed

+1202
-11
lines changed

2 files changed

+1202
-11
lines changed

HRNetClassification/hrnet.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
#include "NvInferPlugin.h"
1111
#include "cuda_runtime_api.h"
1212
#include "logging.h"
13+
#ifdef WIN32
14+
#include "include/dirent.h" // WIN32
15+
#else
16+
#include <dirent.h> // LINUX
17+
#endif
18+
1319

1420
using namespace nvinfer1;
1521
static Logger gLogger;
@@ -33,6 +39,29 @@ static const int OUTPUT_SIZE = 1000;
3339
}\
3440
} while (0)
3541

42+
int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
43+
DIR *p_dir = opendir(p_dir_name);
44+
if (p_dir == nullptr) {
45+
return -1;
46+
}
47+
48+
struct dirent* p_file = nullptr;
49+
while ((p_file = readdir(p_dir)) != nullptr) {
50+
if (strcmp(p_file->d_name, ".") != 0 &&
51+
strcmp(p_file->d_name, "..") != 0) {
52+
//std::string cur_file_name(p_dir_name);
53+
//cur_file_name += "/";
54+
//cur_file_name += p_file->d_name;
55+
std::string cur_file_name(p_file->d_name);
56+
file_names.push_back(cur_file_name);
57+
}
58+
}
59+
60+
closedir(p_dir);
61+
return 0;
62+
}
63+
64+
3665
// TensorRT weight files have a simple space delimited format:
3766
// [type] [size] <data x size in hex>
3867
std::map<std::string, Weights> loadWeights(const std::string file) {
@@ -998,14 +1027,13 @@ void doInference(IExecutionContext& context, float* input, float* output, int ba
9981027

9991028

10001029
int main(int argc, char** argv) {
1030+
10011031
cudaSetDevice(DEVICE);
10021032
// create a model using the API directly and serialize it to a stream
10031033
char *trtModelStream{ nullptr };
10041034
size_t size{ 0 };
10051035
std::string engine_name = "hrnet.engine";
1006-
//engine_name = "E:\\LearningCodes\\GithubRepo\\tensorrtx\\yolov5\\build\\yolov5s.wts";
1007-
argv[1] = "-d";
1008-
if (std::string(argv[1]) == "-s") {
1036+
if (argc == 2 && std::string(argv[1]) == "-s") {
10091037
IHostMemory* modelStream{ nullptr };
10101038
APIToModel(BATCH_SIZE, &modelStream);
10111039
assert(modelStream != nullptr);
@@ -1018,8 +1046,7 @@ int main(int argc, char** argv) {
10181046
modelStream->destroy();
10191047
return 0;
10201048
}
1021-
else if (std::string(argv[1]) == "-d")
1022-
{
1049+
else if (argc == 3 && std::string(argv[1]) == "-d") {
10231050
std::ifstream file(engine_name, std::ios::binary);
10241051
if (file.good()) {
10251052
file.seekg(0, file.end);
@@ -1031,14 +1058,18 @@ int main(int argc, char** argv) {
10311058
file.close();
10321059
}
10331060
}
1061+
else {
1062+
std::cerr << "arguments not right!" << std::endl;
1063+
std::cerr << "./yolov5 -s // serialize model to plan file" << std::endl;
1064+
std::cerr << "./yolov5 -d ../samples // deserialize plan file and run inference" << std::endl;
1065+
return -1;
1066+
}
10341067

10351068
std::vector<std::string> file_names;
1036-
file_names.push_back("E:\\Datasets\\tiny-imagenet-200\\tiny-imagenet-200\\val\\images\\val_41.JPEG");
1037-
//if (read_files_in_dir(argv[2], file_names) < 0) {
1038-
// std::cout << "read_files_in_dir failed." << std::endl;
1039-
// return -1;
1040-
//}
1041-
1069+
if (read_files_in_dir(argv[2], file_names) < 0) {
1070+
std::cout << "read_files_in_dir failed." << std::endl;
1071+
return -1;
1072+
}
10421073
// prepare input data ---------------------------
10431074
static float data[BATCH_SIZE * 3 * INPUT_H * INPUT_W];
10441075
//for (int i = 0; i < 3 * INPUT_H * INPUT_W; i++)

0 commit comments

Comments
 (0)