|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "NvInfer.h" |
| 4 | +#include <string> |
| 5 | +#include <vector> |
| 6 | +#include <iostream> |
| 7 | +#include <iterator> |
| 8 | +#include <fstream> |
| 9 | +#include <algorithm> |
| 10 | +#include "common.hpp" |
| 11 | + |
| 12 | +//! \class Int8EntropyCalibrator2 |
| 13 | +//! |
| 14 | +//! \brief Implements Entropy calibrator 2. |
| 15 | +//! CalibrationAlgoType is kENTROPY_CALIBRATION_2. |
| 16 | +//! |
| 17 | +class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { |
| 18 | + public: |
| 19 | + Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, |
| 20 | + const char* img_dir, const char* calib_table_name, |
| 21 | + const char* input_blob_name, bool read_cache = true); |
| 22 | + |
| 23 | + virtual ~Int8EntropyCalibrator2(); |
| 24 | + int getBatchSize() const override; |
| 25 | + bool getBatch(void* bindings[], const char* names[], int nbBindings) override; |
| 26 | + const void* readCalibrationCache(size_t& length) override; |
| 27 | + void writeCalibrationCache(const void* cache, size_t length) override; |
| 28 | + |
| 29 | + private: |
| 30 | + int batchsize_; |
| 31 | + int input_w_; |
| 32 | + int input_h_; |
| 33 | + int img_idx_; |
| 34 | + std::string img_dir_; |
| 35 | + std::vector<std::string> img_files_; |
| 36 | + size_t input_count_; |
| 37 | + std::string calib_table_name_; |
| 38 | + const char* input_blob_name_; |
| 39 | + bool read_cache_; |
| 40 | + void* device_input_; |
| 41 | + std::vector<char> calib_cache_; |
| 42 | +}; |
| 43 | + |
| 44 | +Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, |
| 45 | +int input_w, int input_h, const char* img_dir, |
| 46 | +const char* calib_table_name, const char* input_blob_name, |
| 47 | +bool read_cache) |
| 48 | + : batchsize_(batchsize) |
| 49 | + , input_w_(input_w) |
| 50 | + , input_h_(input_h) |
| 51 | + , img_idx_(0) |
| 52 | + , img_dir_(img_dir) |
| 53 | + , calib_table_name_(calib_table_name) |
| 54 | + , input_blob_name_(input_blob_name) |
| 55 | + , read_cache_(read_cache) { |
| 56 | + input_count_ = 3 * input_w * input_h * batchsize; |
| 57 | + CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float))); |
| 58 | + read_files_in_dir(img_dir, img_files_); |
| 59 | +} |
| 60 | + |
| 61 | +Int8EntropyCalibrator2::~Int8EntropyCalibrator2() { |
| 62 | + CUDA_CHECK(cudaFree(device_input_)); |
| 63 | +} |
| 64 | + |
| 65 | +int Int8EntropyCalibrator2::getBatchSize() const { |
| 66 | + return batchsize_; |
| 67 | +} |
| 68 | + |
| 69 | +bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) { |
| 70 | + if (img_idx_ + batchsize_ > static_cast<int>(img_files_.size())) { |
| 71 | + return false; |
| 72 | + } |
| 73 | + |
| 74 | + std::vector<float> input_imgs_(input_count_, 0); |
| 75 | + for (int i = img_idx_; i < img_idx_ + batchsize_; i++) { |
| 76 | + std::cout << img_files_[i] << " " << i << std::endl; |
| 77 | + cv::Mat temp = cv::imread(img_dir_ + img_files_[i]); |
| 78 | + if (temp.empty()) { |
| 79 | + std::cerr << "Fatal error: image cannot open!" << std::endl; |
| 80 | + return false; |
| 81 | + } |
| 82 | + preprocessImg(temp, input_w_, input_h_); |
| 83 | + for (int c = 0; c < 3; c++) { |
| 84 | + for (int h = 0; h < input_h_; h++) { |
| 85 | + for (int w = 0; w < input_w_; w++) { |
| 86 | + input_imgs_[(i-img_idx_)*input_w_*input_h_*3 + |
| 87 | + c * input_h_ * input_w_ + h * input_w_ + w] = temp.at<cv::Vec3f>(h, w)[c]; |
| 88 | + } |
| 89 | + } |
| 90 | + } |
| 91 | + } |
| 92 | + img_idx_ += batchsize_; |
| 93 | + |
| 94 | + CUDA_CHECK(cudaMemcpy(device_input_, input_imgs_.data(), input_count_ * sizeof(float), cudaMemcpyHostToDevice)); |
| 95 | + assert(!strcmp(names[0], input_blob_name_)); |
| 96 | + bindings[0] = device_input_; |
| 97 | + return true; |
| 98 | +} |
| 99 | + |
| 100 | +const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) { |
| 101 | + std::cout << "reading calib cache: " << calib_table_name_ << std::endl; |
| 102 | + calib_cache_.clear(); |
| 103 | + std::ifstream input(calib_table_name_, std::ios::binary); |
| 104 | + input >> std::noskipws; |
| 105 | + if (read_cache_ && input.good()) { |
| 106 | + std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(calib_cache_)); |
| 107 | + } |
| 108 | + length = calib_cache_.size(); |
| 109 | + return length ? calib_cache_.data() : nullptr; |
| 110 | +} |
| 111 | + |
| 112 | +void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) { |
| 113 | + std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl; |
| 114 | + std::ofstream output(calib_table_name_, std::ios::binary); |
| 115 | + output.write(reinterpret_cast<const char*>(cache), length); |
| 116 | +} |
0 commit comments