TensorRT转化代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

289 lines
9.3KB

  #include "cuda_utils.h"
  #include "logging.h"
  #include "utils.h"
  #include "model.h"
  #include "config.h"
  #include "calibrator.h"

  #include <algorithm>
  #include <cassert>
  #include <chrono>
  #include <cmath>
  #include <cstddef>
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <numeric>
  #include <stdexcept>
  #include <string>
  #include <vector>

  #include <opencv2/opencv.hpp>
  12. using namespace nvinfer1;
  13. static Logger gLogger;
  14. const static int kOutputSize = kClsNumClass;
  15. void batch_preprocess(std::vector<cv::Mat>& imgs, float* output) {
  16. for (size_t b = 0; b < imgs.size(); b++) {
  17. cv::Mat img;
  18. // cv::resize(imgs[b], img, cv::Size(kClsInputW, kClsInputH));
  19. img = preprocess_img(imgs[b], kClsInputW, kClsInputH);
  20. int i = 0;
  21. for (int row = 0; row < img.rows; ++row) {
  22. uchar* uc_pixel = img.data + row * img.step;
  23. for (int col = 0; col < img.cols; ++col) {
  24. output[b * 3 * img.rows * img.cols + i] = ((float)uc_pixel[2] / 255.0 - 0.485) / 0.229; // R - 0.485
  25. output[b * 3 * img.rows * img.cols + i + img.rows * img.cols] = ((float)uc_pixel[1] / 255.0 - 0.456) / 0.224;
  26. output[b * 3 * img.rows * img.cols + i + 2 * img.rows * img.cols] = ((float)uc_pixel[0] / 255.0 - 0.406) / 0.225;
  27. uc_pixel += 3;
  28. ++i;
  29. }
  30. }
  31. }
  32. }
  // Numerically stable softmax over the first `n` entries of `prob`.
  //
  // The maximum score is subtracted before exponentiating. This is
  // mathematically identical, since softmax(x) == softmax(x - c) for any c,
  // but it keeps expf() from overflowing to +inf on large logits (above ~88
  // for float). That overflow previously turned every output into NaN.
  //
  // @param prob  pointer to at least `n` raw scores; not modified.
  // @param n     number of scores; n <= 0 (or a null `prob`) yields an empty vector.
  // @return      n probabilities summing to 1, in the same order as `prob`.
  std::vector<float> softmax(float *prob, int n) {
    std::vector<float> res;
    if (prob == nullptr || n <= 0) {
      return res;
    }
    res.reserve(static_cast<size_t>(n));

    const float max_logit = *std::max_element(prob, prob + n);
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
      const float e = expf(prob[i] - max_logit);  // every exponent is <= 0
      res.push_back(e);
      sum += e;
    }
    for (float& v : res) {
      v /= sum;  // sum >= 1 because the max element contributes expf(0) == 1
    }
    return res;
  }
  // Returns the indices of the `k` largest values in `vec`, best first.
  //
  // Only the top k entries are ordered (std::partial_sort, O(n log k)) rather
  // than sorting the whole vector. The ordering is fully deterministic:
  //   - equal values are returned in ascending index order;
  //   - NaN values rank below every number.
  // A plain `a > b` comparator is not a strict weak ordering once NaN is
  // present, and that makes the standard sort algorithms undefined behaviour.
  //
  // @param vec  scores to rank.
  // @param k    number of indices requested; clamped to [0, vec.size()].
  // @return     min(max(k, 0), vec.size()) indices into `vec`, highest first.
  std::vector<int> topk(const std::vector<float>& vec, int k) {
    const size_t count = k <= 0 ? 0 : std::min(vec.size(), static_cast<size_t>(k));

    std::vector<size_t> order(vec.size());
    std::iota(order.begin(), order.end(), size_t{0});

    auto ranks_higher = [&vec](size_t a, size_t b) {
      const float va = vec[a];
      const float vb = vec[b];
      const bool a_nan = std::isnan(va);
      const bool b_nan = std::isnan(vb);
      if (a_nan != b_nan) {
        return b_nan;  // any number outranks NaN
      }
      if (!a_nan && va != vb) {
        return va > vb;
      }
      return a < b;  // equal values (or NaN vs NaN): lower index first
    };
    std::partial_sort(order.begin(), order.begin() + static_cast<std::ptrdiff_t>(count),
                      order.end(), ranks_higher);

    std::vector<int> topk_index;
    topk_index.reserve(count);
    for (size_t i = 0; i < count; ++i) {
      topk_index.push_back(static_cast<int>(order[i]));
    }
    return topk_index;
  }
  // Reads class labels from a text file, one label per line.
  //
  // Line i of the file becomes element i of the result, so the file order
  // must match the model's output indices. Blank lines are kept so that
  // alignment is preserved. A trailing '\r' from CRLF line endings is
  // stripped so it does not end up in printed labels.
  //
  // @param file_name  path of the label file.
  // @return           labels in file order.
  // @throws std::runtime_error if the file cannot be opened. This used to be
  //         assert(0), which is compiled out under NDEBUG. In that case an
  //         empty list was returned silently, and callers indexing it by class
  //         id would read out of bounds.
  std::vector<std::string> read_classes(std::string file_name) {
    std::ifstream ifs(file_name, std::ios::in);
    if (!ifs.is_open()) {
      std::cerr << file_name << " is not found, pls refer to README and download it." << std::endl;
      throw std::runtime_error("cannot open class list: " + file_name);
    }

    std::vector<std::string> classes;
    std::string line;
    while (std::getline(ifs, line)) {
      if (!line.empty() && line.back() == '\r') {
        line.pop_back();
      }
      classes.push_back(line);
    }
    return classes;
  }
  // Parses the command line into "serialize" or "deserialize" mode.
  //
  //   -s <wts> <engine> <n|s|m|l|x>      serialize with preset depth/width multiples
  //   -s <wts> <engine> c <gd> <gw>      serialize with explicit multiples
  //   -d <engine> <img_dir>              deserialize and run inference
  //
  // Results are written through the reference parameters. Only the first
  // character of the model-size argument is inspected. In the "-s" form,
  // `wts` and `engine` are assigned before that character is validated, so
  // they may already be modified when false is returned.
  //
  // @return true if argv matched one of the forms above.
  bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir) {
    if (argc < 4) {
      return false;
    }
    const std::string mode(argv[1]);

    if (mode == "-s") {
      if (argc != 5 && argc != 7) {
        return false;
      }
      wts = argv[2];
      engine = argv[3];
      switch (argv[4][0]) {
        case 'n': gd = 0.33; gw = 0.25; return true;
        case 's': gd = 0.33; gw = 0.50; return true;
        case 'm': gd = 0.67; gw = 0.75; return true;
        case 'l': gd = 1.0;  gw = 1.0;  return true;
        case 'x': gd = 1.33; gw = 1.25; return true;
        case 'c':
          // Custom multiples need the two extra arguments.
          if (argc != 7) {
            return false;
          }
          gd = atof(argv[5]);
          gw = atof(argv[6]);
          return true;
        default:
          return false;
      }
    }

    if (mode == "-d" && argc == 4) {
      engine = argv[2];
      img_dir = argv[3];
      return true;
    }
    return false;
  }
  107. void prepare_buffers(ICudaEngine* engine, float** gpu_input_buffer, float** gpu_output_buffer, float** cpu_input_buffer, float** cpu_output_buffer) {
  108. assert(engine->getNbBindings() == 2);
  109. // In order to bind the buffers, we need to know the names of the input and output tensors.
  110. // Note that indices are guaranteed to be less than IEngine::getNbBindings()
  111. const int inputIndex = engine->getBindingIndex(kInputTensorName);
  112. const int outputIndex = engine->getBindingIndex(kOutputTensorName);
  113. assert(inputIndex == 0);
  114. assert(outputIndex == 1);
  115. // Create GPU buffers on device
  116. CUDA_CHECK(cudaMalloc((void**)gpu_input_buffer, kBatchSize * 3 * kClsInputH * kClsInputW * sizeof(float)));
  117. CUDA_CHECK(cudaMalloc((void**)gpu_output_buffer, kBatchSize * kOutputSize * sizeof(float)));
  118. *cpu_input_buffer = new float[kBatchSize * 3 * kClsInputH * kClsInputW];
  119. *cpu_output_buffer = new float[kBatchSize * kOutputSize];
  120. }
  121. void infer(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* input, float* output, int batchSize) {
  122. CUDA_CHECK(cudaMemcpyAsync(buffers[0], input, batchSize * 3 * kClsInputH * kClsInputW * sizeof(float), cudaMemcpyHostToDevice, stream));
  123. context.enqueue(batchSize, buffers, stream, nullptr);
  124. CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost, stream));
  125. cudaStreamSynchronize(stream);
  126. }
  127. void serialize_engine(unsigned int max_batchsize, float& gd, float& gw, std::string& wts_name, std::string& engine_name) {
  128. // Create builder
  129. IBuilder* builder = createInferBuilder(gLogger);
  130. IBuilderConfig* config = builder->createBuilderConfig();
  131. // Create model to populate the network, then set the outputs and create an engine
  132. ICudaEngine *engine = nullptr;
  133. engine = build_cls_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);
  134. assert(engine != nullptr);
  135. // Serialize the engine
  136. IHostMemory* serialized_engine = engine->serialize();
  137. assert(serialized_engine != nullptr);
  138. // Save engine to file
  139. std::ofstream p(engine_name, std::ios::binary);
  140. if (!p) {
  141. std::cerr << "Could not open plan output file" << std::endl;
  142. assert(false);
  143. }
  144. p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());
  145. // Close everything down
  146. engine->destroy();
  147. config->destroy();
  148. serialized_engine->destroy();
  149. builder->destroy();
  150. }
  151. void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) {
  152. std::ifstream file(engine_name, std::ios::binary);
  153. if (!file.good()) {
  154. std::cerr << "read " << engine_name << " error!" << std::endl;
  155. assert(false);
  156. }
  157. size_t size = 0;
  158. file.seekg(0, file.end);
  159. size = file.tellg();
  160. file.seekg(0, file.beg);
  161. char* serialized_engine = new char[size];
  162. assert(serialized_engine);
  163. file.read(serialized_engine, size);
  164. file.close();
  165. *runtime = createInferRuntime(gLogger);
  166. assert(*runtime);
  167. *engine = (*runtime)->deserializeCudaEngine(serialized_engine, size);
  168. assert(*engine);
  169. *context = (*engine)->createExecutionContext();
  170. assert(*context);
  171. delete[] serialized_engine;
  172. }
  173. int main(int argc, char** argv) {
  174. cudaSetDevice(kGpuId);
  175. std::string wts_name = "";
  176. std::string engine_name = "";
  177. float gd = 0.0f, gw = 0.0f;
  178. std::string img_dir;
  179. if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir)) {
  180. std::cerr << "arguments not right!" << std::endl;
  181. std::cerr << "./yolov5_cls -s [.wts] [.engine] [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl;
  182. std::cerr << "./yolov5_cls -d [.engine] ../images // deserialize plan file and run inference" << std::endl;
  183. return -1;
  184. }
  185. // Create a model using the API directly and serialize it to a file
  186. if (!wts_name.empty()) {
  187. serialize_engine(kBatchSize, gd, gw, wts_name, engine_name);
  188. return 0;
  189. }
  190. // Deserialize the engine from file
  191. IRuntime* runtime = nullptr;
  192. ICudaEngine* engine = nullptr;
  193. IExecutionContext* context = nullptr;
  194. deserialize_engine(engine_name, &runtime, &engine, &context);
  195. cudaStream_t stream;
  196. CUDA_CHECK(cudaStreamCreate(&stream));
  197. // Prepare cpu and gpu buffers
  198. float* gpu_buffers[2];
  199. float* cpu_input_buffer = nullptr;
  200. float* cpu_output_buffer = nullptr;
  201. prepare_buffers(engine, &gpu_buffers[0], &gpu_buffers[1], &cpu_input_buffer, &cpu_output_buffer);
  202. // Read images from directory
  203. std::vector<std::string> file_names;
  204. if (read_files_in_dir(img_dir.c_str(), file_names) < 0) {
  205. std::cerr << "read_files_in_dir failed." << std::endl;
  206. return -1;
  207. }
  208. // Read imagenet labels
  209. auto classes = read_classes("imagenet_classes.txt");
  210. // batch predict
  211. for (size_t i = 0; i < file_names.size(); i += kBatchSize) {
  212. // Get a batch of images
  213. std::vector<cv::Mat> img_batch;
  214. std::vector<std::string> img_name_batch;
  215. for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) {
  216. cv::Mat img = cv::imread(img_dir + "/" + file_names[j]);
  217. img_batch.push_back(img);
  218. img_name_batch.push_back(file_names[j]);
  219. }
  220. // Preprocess
  221. batch_preprocess(img_batch, cpu_input_buffer);
  222. // Run inference
  223. auto start = std::chrono::system_clock::now();
  224. infer(*context, stream, (void**)gpu_buffers, cpu_input_buffer, cpu_output_buffer, kBatchSize);
  225. auto end = std::chrono::system_clock::now();
  226. std::cout << "inference time: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
  227. // Postprocess and get top-k result
  228. for (size_t b = 0; b < img_name_batch.size(); b++) {
  229. float* p = &cpu_output_buffer[b * kOutputSize];
  230. auto res = softmax(p, kOutputSize);
  231. auto topk_idx = topk(res, 3);
  232. std::cout << img_name_batch[b] << std::endl;
  233. for (auto idx: topk_idx) {
  234. std::cout << " " << classes[idx] << " " << res[idx] << std::endl;
  235. }
  236. }
  237. }
  238. // Release stream and buffers
  239. cudaStreamDestroy(stream);
  240. CUDA_CHECK(cudaFree(gpu_buffers[0]));
  241. CUDA_CHECK(cudaFree(gpu_buffers[1]));
  242. delete[] cpu_input_buffer;
  243. delete[] cpu_output_buffer;
  244. // Destroy the engine
  245. context->destroy();
  246. engine->destroy();
  247. runtime->destroy();
  248. return 0;
  249. }