TensorRT转化代码
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

yolov5_seg.cpp 8.6KB

10 달 전
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  #include "config.h"
  #include "cuda_utils.h"
  #include "logging.h"
  #include "utils.h"
  #include "preprocess.h"
  #include "postprocess.h"
  #include "model.h"

  #include <cassert>
  #include <chrono>
  #include <cmath>
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <string>
  #include <unordered_map>
  #include <vector>
  11. using namespace nvinfer1;
  12. static Logger gLogger;
  13. const static int kOutputSize1 = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
  14. const static int kOutputSize2 = 32 * (kInputH / 4) * (kInputW / 4);
  // Parse the command line into the output parameters.
  //
  // Accepted forms:
  //   -s <wts> <engine> <n|s|m|l|x>      (argc == 5 or 7)  build + serialize
  //   -s <wts> <engine> c <gd> <gw>      (argc == 7)       custom gd / gw
  //   -d <engine> <img_dir> <labels>     (argc == 5)       deserialize + infer
  // Only the first character of the model-size argument is inspected.
  // Returns false when the arguments match none of the forms; in that case some
  // output parameters (e.g. wts / engine) may already have been written.
  bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename) {
      if (argc < 4) return false;

      const std::string mode(argv[1]);

      // Inference mode.
      if (mode == "-d") {
          if (argc != 5) return false;
          engine = argv[2];
          img_dir = argv[3];
          labels_filename = argv[4];
          return true;
      }

      // Anything else must be a well-formed serialize request.
      if (mode != "-s" || (argc != 5 && argc != 7)) return false;

      wts = argv[2];
      engine = argv[3];
      const std::string net(argv[4]);

      // Map the model-size letter to its (gd, gw) pair.
      switch (net[0]) {
          case 'n': gd = 0.33; gw = 0.25; break;
          case 's': gd = 0.33; gw = 0.50; break;
          case 'm': gd = 0.67; gw = 0.75; break;
          case 'l': gd = 1.0;  gw = 1.0;  break;
          case 'x': gd = 1.33; gw = 1.25; break;
          case 'c':
              // Custom factors are only accepted when both values were supplied.
              if (argc != 7) return false;
              gd = atof(argv[5]);
              gw = atof(argv[6]);
              break;
          default:
              return false;
      }
      return true;
  }
  51. void prepare_buffers(ICudaEngine* engine, float** gpu_input_buffer, float** gpu_output_buffer1, float** gpu_output_buffer2, float** cpu_output_buffer1, float** cpu_output_buffer2) {
  52. assert(engine->getNbBindings() == 3);
  53. // In order to bind the buffers, we need to know the names of the input and output tensors.
  54. // Note that indices are guaranteed to be less than IEngine::getNbBindings()
  55. const int inputIndex = engine->getBindingIndex(kInputTensorName);
  56. const int outputIndex1 = engine->getBindingIndex(kOutputTensorName);
  57. const int outputIndex2 = engine->getBindingIndex("proto");
  58. assert(inputIndex == 0);
  59. assert(outputIndex1 == 1);
  60. assert(outputIndex2 == 2);
  61. // Create GPU buffers on device
  62. CUDA_CHECK(cudaMalloc((void**)gpu_input_buffer, kBatchSize * 3 * kInputH * kInputW * sizeof(float)));
  63. CUDA_CHECK(cudaMalloc((void**)gpu_output_buffer1, kBatchSize * kOutputSize1 * sizeof(float)));
  64. CUDA_CHECK(cudaMalloc((void**)gpu_output_buffer2, kBatchSize * kOutputSize2 * sizeof(float)));
  65. // Alloc CPU buffers
  66. *cpu_output_buffer1 = new float[kBatchSize * kOutputSize1];
  67. *cpu_output_buffer2 = new float[kBatchSize * kOutputSize2];
  68. }
  69. void infer(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output1, float* output2, int batchSize) {
  70. context.enqueue(batchSize, buffers, stream, nullptr);
  71. CUDA_CHECK(cudaMemcpyAsync(output1, buffers[1], batchSize * kOutputSize1 * sizeof(float), cudaMemcpyDeviceToHost, stream));
  72. CUDA_CHECK(cudaMemcpyAsync(output2, buffers[2], batchSize * kOutputSize2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
  73. cudaStreamSynchronize(stream);
  74. }
  75. void serialize_engine(unsigned int max_batchsize, float& gd, float& gw, std::string& wts_name, std::string& engine_name) {
  76. // Create builder
  77. IBuilder* builder = createInferBuilder(gLogger);
  78. IBuilderConfig* config = builder->createBuilderConfig();
  79. // Create model to populate the network, then set the outputs and create an engine
  80. ICudaEngine *engine = nullptr;
  81. engine = build_seg_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);
  82. assert(engine != nullptr);
  83. // Serialize the engine
  84. IHostMemory* serialized_engine = engine->serialize();
  85. assert(serialized_engine != nullptr);
  86. // Save engine to file
  87. std::ofstream p(engine_name, std::ios::binary);
  88. if (!p) {
  89. std::cerr << "Could not open plan output file" << std::endl;
  90. assert(false);
  91. }
  92. p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());
  93. // Close everything down
  94. engine->destroy();
  95. config->destroy();
  96. serialized_engine->destroy();
  97. builder->destroy();
  98. }
  99. void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) {
  100. std::ifstream file(engine_name, std::ios::binary);
  101. if (!file.good()) {
  102. std::cerr << "read " << engine_name << " error!" << std::endl;
  103. assert(false);
  104. }
  105. size_t size = 0;
  106. file.seekg(0, file.end);
  107. size = file.tellg();
  108. file.seekg(0, file.beg);
  109. char* serialized_engine = new char[size];
  110. assert(serialized_engine);
  111. file.read(serialized_engine, size);
  112. file.close();
  113. *runtime = createInferRuntime(gLogger);
  114. assert(*runtime);
  115. *engine = (*runtime)->deserializeCudaEngine(serialized_engine, size);
  116. assert(*engine);
  117. *context = (*engine)->createExecutionContext();
  118. assert(*context);
  119. delete[] serialized_engine;
  120. }
  121. int main(int argc, char** argv) {
  122. cudaSetDevice(kGpuId);
  123. std::string wts_name = "";
  124. std::string engine_name = "";
  125. std::string labels_filename = "";
  126. float gd = 0.0f, gw = 0.0f;
  127. std::string img_dir;
  128. if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename)) {
  129. std::cerr << "arguments not right!" << std::endl;
  130. std::cerr << "./yolov5_seg -s [.wts] [.engine] [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl;
  131. std::cerr << "./yolov5_seg -d [.engine] ../images coco.txt // deserialize plan file, read the labels file and run inference" << std::endl;
  132. return -1;
  133. }
  134. // Create a model using the API directly and serialize it to a file
  135. if (!wts_name.empty()) {
  136. serialize_engine(kBatchSize, gd, gw, wts_name, engine_name);
  137. return 0;
  138. }
  139. // Deserialize the engine from file
  140. IRuntime* runtime = nullptr;
  141. ICudaEngine* engine = nullptr;
  142. IExecutionContext* context = nullptr;
  143. deserialize_engine(engine_name, &runtime, &engine, &context);
  144. cudaStream_t stream;
  145. CUDA_CHECK(cudaStreamCreate(&stream));
  146. // Init CUDA preprocessing
  147. cuda_preprocess_init(kMaxInputImageSize);
  148. // Prepare cpu and gpu buffers
  149. float* gpu_buffers[3];
  150. float* cpu_output_buffer1 = nullptr;
  151. float* cpu_output_buffer2 = nullptr;
  152. prepare_buffers(engine, &gpu_buffers[0], &gpu_buffers[1], &gpu_buffers[2], &cpu_output_buffer1, &cpu_output_buffer2);
  153. // Read images from directory
  154. std::vector<std::string> file_names;
  155. if (read_files_in_dir(img_dir.c_str(), file_names) < 0) {
  156. std::cerr << "read_files_in_dir failed." << std::endl;
  157. return -1;
  158. }
  159. // Read the txt file for classnames
  160. std::ifstream labels_file(labels_filename, std::ios::binary);
  161. if (!labels_file.good()) {
  162. std::cerr << "read " << labels_filename << " error!" << std::endl;
  163. return -1;
  164. }
  165. std::unordered_map<int, std::string> labels_map;
  166. read_labels(labels_filename, labels_map);
  167. assert(kNumClass == labels_map.size());
  168. // batch predict
  169. for (size_t i = 0; i < file_names.size(); i += kBatchSize) {
  170. // Get a batch of images
  171. std::vector<cv::Mat> img_batch;
  172. std::vector<std::string> img_name_batch;
  173. for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) {
  174. cv::Mat img = cv::imread(img_dir + "/" + file_names[j]);
  175. img_batch.push_back(img);
  176. img_name_batch.push_back(file_names[j]);
  177. }
  178. // Preprocess
  179. cuda_batch_preprocess(img_batch, gpu_buffers[0], kInputW, kInputH, stream);
  180. // Run inference
  181. auto start = std::chrono::system_clock::now();
  182. infer(*context, stream, (void**)gpu_buffers, cpu_output_buffer1, cpu_output_buffer2, kBatchSize);
  183. auto end = std::chrono::system_clock::now();
  184. std::cout << "inference time: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
  185. // NMS
  186. std::vector<std::vector<Detection>> res_batch;
  187. batch_nms(res_batch, cpu_output_buffer1, img_batch.size(), kOutputSize1, kConfThresh, kNmsThresh);
  188. // Draw result and save image
  189. for (size_t b = 0; b < img_name_batch.size(); b++) {
  190. auto& res = res_batch[b];
  191. cv::Mat img = img_batch[b];
  192. auto masks = process_mask(&cpu_output_buffer2[b * kOutputSize2], kOutputSize2, res);
  193. draw_mask_bbox(img, res, masks, labels_map);
  194. cv::imwrite("_" + img_name_batch[b], img);
  195. }
  196. }
  197. // Release stream and buffers
  198. cudaStreamDestroy(stream);
  199. CUDA_CHECK(cudaFree(gpu_buffers[0]));
  200. CUDA_CHECK(cudaFree(gpu_buffers[1]));
  201. CUDA_CHECK(cudaFree(gpu_buffers[2]));
  202. delete[] cpu_output_buffer1;
  203. delete[] cpu_output_buffer2;
  204. cuda_preprocess_destroy();
  205. // Destroy the engine
  206. context->destroy();
  207. engine->destroy();
  208. runtime->destroy();
  209. return 0;
  210. }