Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from models.experimental import attempt_load
  2. import tensorrt as trt
  3. import sys
  4. from segutils.trtUtils import yolov5Trtforward
  5. from utilsK.queRiver import getDetectionsFromPreds,img_pad
  6. from utils.datasets import letterbox
  7. import numpy as np
  8. import torch,time
  9. def score_filter_byClass(pdetections,score_para_2nd):
  10. ret=[]
  11. for det in pdetections:
  12. score,cls = det[4],det[5]
  13. if int(cls) in score_para_2nd.keys():
  14. score_th = score_para_2nd[int(cls)]
  15. elif str(int(cls)) in score_para_2nd.keys():
  16. score_th = score_para_2nd[str(int(cls))]
  17. else:
  18. score_th = 0.7
  19. if score > score_th:
  20. ret.append(det)
  21. return ret
  22. class yolov5Model(object):
  23. def __init__(self, weights=None,par={}):
  24. self.par = par
  25. self.device = par['device']
  26. self.half =par['half']
  27. if weights.endswith('.engine'):
  28. self. infer_type ='trt'
  29. elif weights.endswith('.pth') or weights.endswith('.pt') :
  30. self. infer_type ='pth'
  31. else:
  32. print('#########ERROR:',weights,': no registered inference type, exit')
  33. sys.exit(0)
  34. if self.infer_type=='trt':
  35. logger = trt.Logger(trt.Logger.ERROR)
  36. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  37. self.model=runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  38. #print('####load TRT model :%s'%(weights))
  39. elif self.infer_type=='pth':
  40. self.model = attempt_load(weights, map_location=self.device) # load FP32 model
  41. if self.half: self.model.half()
  42. if 'score_byClass' in par.keys(): self.score_byClass = par['score_byClass']
  43. else: self.score_byClass = None
  44. print('#########加载模型:',weights,' 类型:',self.infer_type)
  45. def eval(self,image):
  46. t0=time.time()
  47. img = self.preprocess_image(image)
  48. t1=time.time()
  49. if self.infer_type=='trt':
  50. pred = yolov5Trtforward(self.model,img)
  51. else:
  52. pred = self.model(img,augment=False)[0]
  53. t2=time.time()
  54. if 'ovlap_thres_crossCategory' in self.par.keys():
  55. ovlap_thres = self.par['ovlap_thres_crossCategory']
  56. else:
  57. ovlap_thres = None
  58. p_result, timeOut = getDetectionsFromPreds(pred,img,image,conf_thres=self.par['conf_thres'],iou_thres=self.par['iou_thres'],ovlap_thres=ovlap_thres,padInfos=self.padInfos)
  59. if self.score_byClass:
  60. p_result[2] = score_filter_byClass(p_result[2],self.score_byClass)
  61. t3=time.time()
  62. timeOut = 'yolov5 :%.1f (pre-process:%.1f, inference:%.1f, post-process:%.1f) '%( self.get_ms(t3,t0) , self.get_ms(t1,t0) , self.get_ms(t2,t1) , self.get_ms(t3,t2) )
  63. return p_result[2], timeOut
  64. def get_ms(self,t1,t0):
  65. return (t1-t0)*1000.0
  66. def preprocess_image(self,image):
  67. if self.infer_type=='trt':
  68. img, padInfos = img_pad( image , size=(640,640,3)) ;img = [img]
  69. self.padInfos =padInfos
  70. else:
  71. img = [letterbox(x, 640, auto=True, stride=32)[0] for x in [image]];
  72. self.padInfos=None
  73. # Stack
  74. img = np.stack(img, 0)
  75. # Convert
  76. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  77. img = np.ascontiguousarray(img)
  78. img = torch.from_numpy(img).to(self.device)
  79. img = img.half() if self.half else img.float() # uint8 to fp16/32
  80. img /= 255.0
  81. return img