TensorRT转化代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

281 lines
10KB

  #include "yololayer.h"
  #include "cuda_utils.h"
  #include <cassert>
  #include <cstring>
  #include <iostream>
  #include <type_traits>
  #include <vector>
  namespace Tn {
  // Append the raw bytes of `val` at `buffer` and advance `buffer` past them.
  // std::memcpy is used instead of a typed store through reinterpret_cast: a serialization
  // cursor is not guaranteed to be suitably aligned for T (e.g. when a field follows a bool),
  // and a typed store through a char buffer is undefined behaviour in that case.
  template<typename T>
  void write(char*& buffer, const T& val) {
      static_assert(std::is_trivially_copyable<T>::value, "Tn::write requires a trivially copyable type");
      std::memcpy(buffer, &val, sizeof(T));
      buffer += sizeof(T);
  }
  // Read sizeof(T) raw bytes from `buffer` into `val` and advance `buffer` past them.
  // Counterpart of write(); uses std::memcpy for the same alignment/aliasing reasons.
  template<typename T>
  void read(const char*& buffer, T& val) {
      static_assert(std::is_trivially_copyable<T>::value, "Tn::read requires a trivially copyable type");
      std::memcpy(&val, buffer, sizeof(T));
      buffer += sizeof(T);
  }
  }  // namespace Tn
  12. template<typename T>
  13. void read(const char*& buffer, T& val) {
  14. val = *reinterpret_cast<const T*>(buffer);
  15. buffer += sizeof(T);
  16. }
  17. }
  18. namespace nvinfer1 {
  // Build-time constructor: stores the network geometry and the per-head YoloKernel descriptions,
  // then uploads each head's anchor table (kNumAnchor (w, h) float pairs) into its own device buffer.
  YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel)
  {
      mClassCount = classCount;
      mYoloV5NetWidth = netWidth;
      mYoloV5NetHeight = netHeight;
      mMaxOutObject = maxOut;
      is_segmentation_ = is_segmentation;
      mYoloKernel = vYoloKernel;
      mKernelCount = vYoloKernel.size();

      // Page-locked host array holding one device pointer per detection head.
      CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));

      const size_t anchorBytes = sizeof(float) * kNumAnchor * 2;
      for (int head = 0; head < mKernelCount; ++head) {
          CUDA_CHECK(cudaMalloc(&mAnchor[head], anchorBytes));
          CUDA_CHECK(cudaMemcpy(mAnchor[head], mYoloKernel[head].anchors, anchorBytes, cudaMemcpyHostToDevice));
      }
  }
  // Release every per-head device anchor buffer, then the page-locked host array that indexed them.
  YoloLayerPlugin::~YoloLayerPlugin()
  {
      for (int head = 0; head < mKernelCount; ++head) {
          CUDA_CHECK(cudaFree(mAnchor[head]));
      }
      CUDA_CHECK(cudaFreeHost(mAnchor));
  }
  // Runtime constructor: rebuild the plugin from the byte stream written by serialize().
  // The field order here must stay in lock-step with serialize() and getSerializationSize().
  YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length)
  {
      const char* const begin = static_cast<const char*>(data);
      const char* cursor = begin;

      // Scalar header.
      Tn::read(cursor, mClassCount);
      Tn::read(cursor, mThreadCount);
      Tn::read(cursor, mKernelCount);
      Tn::read(cursor, mYoloV5NetWidth);
      Tn::read(cursor, mYoloV5NetHeight);
      Tn::read(cursor, mMaxOutObject);
      Tn::read(cursor, is_segmentation_);

      // Raw YoloKernel array, one entry per detection head.
      mYoloKernel.resize(mKernelCount);
      const size_t kernelBytes = mKernelCount * sizeof(YoloKernel);
      memcpy(mYoloKernel.data(), cursor, kernelBytes);
      cursor += kernelBytes;

      // Re-upload each head's anchors into freshly allocated device memory.
      CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
      const size_t anchorBytes = sizeof(float) * kNumAnchor * 2;
      for (int head = 0; head < mKernelCount; ++head) {
          CUDA_CHECK(cudaMalloc(&mAnchor[head], anchorBytes));
          CUDA_CHECK(cudaMemcpy(mAnchor[head], mYoloKernel[head].anchors, anchorBytes, cudaMemcpyHostToDevice));
      }

      // Every byte of the stream must have been consumed.
      assert(cursor == begin + length);
  }
  // Write the plugin state into `buffer`, which must hold getSerializationSize() bytes.
  // Layout: seven scalar fields followed by the raw YoloKernel array (read back by the
  // (const void*, size_t) constructor in the same order).
  void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT
  {
      char* const begin = static_cast<char*>(buffer);
      char* cursor = begin;

      Tn::write(cursor, mClassCount);
      Tn::write(cursor, mThreadCount);
      Tn::write(cursor, mKernelCount);
      Tn::write(cursor, mYoloV5NetWidth);
      Tn::write(cursor, mYoloV5NetHeight);
      Tn::write(cursor, mMaxOutObject);
      Tn::write(cursor, is_segmentation_);

      const size_t kernelBytes = mKernelCount * sizeof(YoloKernel);
      memcpy(cursor, mYoloKernel.data(), kernelBytes);
      cursor += kernelBytes;

      assert(cursor == begin + getSerializationSize());
  }
  // Size in bytes of the stream produced by serialize(): the scalar header plus the YoloKernel array.
  size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT
  {
      const size_t headerBytes = sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount)
                               + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight)
                               + sizeof(mMaxOutObject) + sizeof(is_segmentation_);
      const size_t kernelBytes = mYoloKernel.size() * sizeof(YoloKernel);
      return headerBytes + kernelBytes;
  }
  // No deferred setup: the anchor buffers are already allocated by the constructors.
  int YoloLayerPlugin::initialize() TRT_NOEXCEPT
  {
      return 0;
  }

  // Output shape (1 + mMaxOutObject * sizeof(Detection) / sizeof(float), 1, 1): one float used as a
  // detection counter, followed by room for mMaxOutObject packed Detection records.
  Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT
  {
      const int detectionFloats = mMaxOutObject * sizeof(Detection) / sizeof(float);
      return Dims3(detectionFloats + 1, 1, 1);
  }
  // Store the plugin namespace.
  // NOTE(review): only the pointer is kept, so the string presumably must outlive this plugin — confirm with callers.
  void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT
  {
      mPluginNamespace = pluginNamespace;
  }

  const char* YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT
  {
      return mPluginNamespace;
  }

  // The output is always float32, whatever the input types.
  DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT
  {
      return DataType::kFLOAT;
  }

  // Outputs are never broadcast across the batch.
  bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT
  {
      return false;
  }

  // Inputs that are broadcast across the batch are not accepted without replication.
  bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT
  {
      return false;
  }

  // Intentionally empty: the plugin needs no tensor-format configuration.
  void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT {}

  // Intentionally empty: no cuDNN/cuBLAS handles or TensorRT allocator are used.
  void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}

  // Intentionally empty: nothing was acquired in attachToContext().
  void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {}

  // Plugin identity; the same strings are returned by YoloPluginCreator::getPluginName()/getPluginVersion().
  const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT
  {
      return "YoloLayer_TRT";
  }

  const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT
  {
      return "1";
  }

  // Self-delete; objects handed out by the creator and clone() are released through this call.
  void YoloLayerPlugin::destroy() TRT_NOEXCEPT
  {
      delete this;
  }
  // Clone the plugin.
  // The build-time constructor re-uploads the anchors into fresh device buffers, so the copy owns its
  // own GPU memory and can be destroyed independently of this instance.
  IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT
  {
      YoloLayerPlugin* copy = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, is_segmentation_, mYoloKernel);
      // That constructor does not take a thread count, so carry it over explicitly. Otherwise a clone of
      // a deserialized plugin silently drops the serialized value, and would launch and re-serialize
      // differently from the original.
      copy->mThreadCount = mThreadCount;
      copy->setPluginNamespace(mPluginNamespace);
      return copy;
  }
  // Logistic sigmoid: 1 / (1 + e^-x).
  __device__ float Logist(float data)
  {
      return 1.0f / (1.0f + expf(-data));
  }

  // Decode one YOLO head into per-image detection lists.
  // Launch: 1-D grid with at least noElements threads; threads with index >= noElements exit.
  // Thread t handles image t / (yoloWidth*yoloHeight) and grid cell t % (yoloWidth*yoloHeight).
  // Input indexing implies a per-image layout of [kNumAnchor][infoLen][yoloHeight*yoloWidth], where
  // infoLen = 5 + classes (+32 mask coefficients when is_segmentation).
  // Output: each image owns outputElem floats; float 0 is a counter bumped with atomicAdd, followed by
  // Detection records. The counter is bumped before the capacity check, so it can exceed maxoutobject;
  // only the first maxoutobject records are written.
  __global__ void CalDetection(const float *input, float *output, int noElements,
                               const int netwidth, const int netheight, int maxoutobject, int yoloWidth,
                               int yoloHeight, const float anchors[kNumAnchor * 2], int classes, int outputElem, bool is_segmentation)
  {
      const int tid = threadIdx.x + blockDim.x * blockIdx.x;
      if (tid >= noElements) return;

      const int cellsPerImage = yoloWidth * yoloHeight;
      const int image = tid / cellsPerImage;
      const int cell = tid - cellsPerImage * image;
      const int infoLen = is_segmentation ? 5 + classes + 32 : 5 + classes;
      const int anchorStride = infoLen * cellsPerImage;

      const float* imageInput = input + image * (anchorStride * kNumAnchor);
      float* counter = output + image * outputElem;
      const int row = cell / yoloWidth;
      const int col = cell % yoloWidth;

      for (int a = 0; a < kNumAnchor; ++a) {
          // Channel c of anchor `a` at this cell lives at ch[c * cellsPerImage].
          const float* ch = imageInput + cell + a * anchorStride;

          const float objectness = Logist(ch[4 * cellsPerImage]);
          if (objectness < kIgnoreThresh) continue;

          // Arg-max over class scores; strict '>' keeps the first class on ties.
          int bestClass = 0;
          float bestProb = 0.0f;
          for (int c = 0; c < classes; ++c) {
              const float prob = Logist(ch[(5 + c) * cellsPerImage]);
              if (prob > bestProb) {
                  bestProb = prob;
                  bestClass = c;
              }
          }

          const int slot = (int)atomicAdd(counter, 1);
          if (slot >= maxoutobject) return;
          Detection* det = reinterpret_cast<Detection*>(reinterpret_cast<char*>(counter) + sizeof(float) + slot * sizeof(Detection));

          // Box centre in network-input pixels, width/height scaled by this anchor.
          det->bbox[0] = (col - 0.5f + 2.0f * Logist(ch[0 * cellsPerImage])) * netwidth / yoloWidth;
          det->bbox[1] = (row - 0.5f + 2.0f * Logist(ch[1 * cellsPerImage])) * netheight / yoloHeight;
          const float bw = 2.0f * Logist(ch[2 * cellsPerImage]);
          det->bbox[2] = bw * bw * anchors[2 * a];
          const float bh = 2.0f * Logist(ch[3 * cellsPerImage]);
          det->bbox[3] = bh * bh * anchors[2 * a + 1];
          det->conf = objectness * bestProb;
          det->class_id = bestClass;

          // Raw (un-activated) mask coefficients follow the class scores.
          if (is_segmentation) {
              for (int m = 0; m < 32; ++m) {
                  det->mask[m] = ch[(m + 5 + classes) * cellsPerImage];
              }
          }
      }
  }
  // Decode all YOLO heads for `batchSize` images into `output`, asynchronously on `stream`.
  // Each image owns outputElem floats (matching getOutputDimensions()): float 0 is a detection
  // counter that CalDetection increments atomically, followed by up to mMaxOutObject Detection records.
  void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize)
  {
      const int outputElem = static_cast<int>(1 + mMaxOutObject * sizeof(Detection) / sizeof(float));
      if (batchSize <= 0) {
          // Nothing to decode; also keeps the block-size computation below away from zero.
          return;
      }

      // Zero every image's counter; stream order guarantees this precedes the kernels below.
      for (int b = 0; b < batchSize; ++b) {
          CUDA_CHECK(cudaMemsetAsync(output + b * outputElem, 0, sizeof(float), stream));
      }

      for (size_t i = 0; i < mYoloKernel.size(); ++i) {
          const auto& yolo = mYoloKernel[i];
          const int numElem = yolo.width * yolo.height * batchSize;  // one thread per (image, grid cell)
          if (numElem <= 0) {
              continue;
          }
          // Clamp the block size in a local rather than overwriting mThreadCount. Mutating the member
          // permanently shrinks every later launch, leaks into serialize(), and a zero numElem would
          // set it to 0 and divide by zero here.
          const int threads = numElem < mThreadCount ? numElem : mThreadCount;
          const int blocks = (numElem + threads - 1) / threads;
          CalDetection<<<blocks, threads, 0, stream>>>(
              inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject,
              yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem, is_segmentation_);
          // Kernel launches do not return an error code; surface bad launch configurations here.
          CUDA_CHECK(cudaGetLastError());
      }
  }
  // TensorRT execution entry point: decode all heads of `batchSize` images into outputs[0] on `stream`.
  // Always reports success (0); the workspace is unused.
  int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT
  {
      const float* const* headInputs = reinterpret_cast<const float* const*>(inputs);
      float* detections = static_cast<float*>(outputs[0]);
      forwardGpu(headInputs, detections, stream, batchSize);
      return 0;
  }
  // Field list shared by all creator instances and returned from getFieldNames().
  PluginFieldCollection YoloPluginCreator::mFC{};
  std::vector<PluginField> YoloPluginCreator::mPluginAttributes;

  // Advertise an empty attribute list; createPlugin() reads its "netinfo"/"kernels" fields positionally.
  YoloPluginCreator::YoloPluginCreator()
  {
      mPluginAttributes.clear();
      mFC.fields = mPluginAttributes.data();
      mFC.nbFields = mPluginAttributes.size();
  }
  // Creator identity; the same strings are returned by YoloLayerPlugin::getPluginType()/getPluginVersion().
  const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT
  {
      return "YoloLayer_TRT";
  }

  const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT
  {
      return "1";
  }

  // The (empty) attribute list populated in the constructor.
  const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT
  {
      return &mFC;
  }
  // Build a YoloLayerPlugin from a PluginFieldCollection. Expected fields, by position:
  //   fields[0] "netinfo": 5 ints = {classCount, inputW, inputH, maxOutputObjects, isSegmentation}
  //   fields[1] "kernels": `length` YoloKernel structs, copied byte-for-byte
  // The field `type` tags are not inspected.
  // Returns nullptr (after printing a diagnostic) for a malformed collection. The checks are explicit
  // rather than assert()s, because asserts compile away under NDEBUG and release builds would then
  // dereference whatever they were given.
  IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT
  {
      auto fail = [](const char* why) -> IPluginV2IOExt* {
          std::cerr << "YoloPluginCreator::createPlugin: " << why << std::endl;
          return nullptr;
      };

      if (fc == nullptr || fc->fields == nullptr || fc->nbFields != 2) {
          return fail("expected exactly 2 plugin fields");
      }
      const PluginField& netinfoField = fc->fields[0];
      const PluginField& kernelField = fc->fields[1];
      if (netinfoField.name == nullptr || strcmp(netinfoField.name, "netinfo") != 0 ||
          kernelField.name == nullptr || strcmp(kernelField.name, "kernels") != 0) {
          return fail("expected fields \"netinfo\" and \"kernels\"");
      }
      // Five ints are read from netinfo below.
      if (netinfoField.data == nullptr || netinfoField.length < 5) {
          return fail("\"netinfo\" must hold 5 ints");
      }
      if (kernelField.data == nullptr || kernelField.length <= 0) {
          return fail("\"kernels\" must hold at least one YoloKernel");
      }

      const int* p_netinfo = static_cast<const int*>(netinfoField.data);
      const int class_count = p_netinfo[0];
      const int input_w = p_netinfo[1];
      const int input_h = p_netinfo[2];
      const int max_output_object_count = p_netinfo[3];
      const bool is_segmentation = p_netinfo[4] != 0;

      std::vector<YoloKernel> kernels(static_cast<size_t>(kernelField.length));
      memcpy(kernels.data(), kernelField.data, kernels.size() * sizeof(YoloKernel));

      YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, is_segmentation, kernels);
      obj->setPluginNamespace(mNamespace.c_str());
      return obj;
  }
  // Recreate a plugin from bytes written by YoloLayerPlugin::serialize().
  // The returned object is deleted when the network is destroyed, via YoloLayerPlugin::destroy().
  IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
  {
      YoloLayerPlugin* plugin = new YoloLayerPlugin(serialData, serialLength);
      plugin->setPluginNamespace(mNamespace.c_str());
      return plugin;
  }
  235. }