Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

1 ano atrás
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import cv2,os,time,json
  2. from models.experimental import attempt_load
  3. from segutils.segmodel import SegModel,get_largest_contours
  4. from segutils.trtUtils import segtrtEval,yolov5Trtforward
  5. from utils.torch_utils import select_device
  6. from utilsK.queRiver import get_labelnames,get_label_arrays,post_process_,img_pad
  7. from utils.datasets import letterbox
  8. import numpy as np
  9. import torch
  10. def get_postProcess_para(parfile):
  11. with open(parfile) as fp:
  12. par = json.load(fp)
  13. assert 'post_process' in par.keys(), ' parfile has not key word:post_process'
  14. parPost=par['post_process']
  15. return parPost["conf_thres"],parPost["iou_thres"],parPost["classes"],parPost["rainbows"]
  16. def AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,objectPar={ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'slopeIndex':[5,6,7],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False }, font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3} ,segPar={'modelSize':(640,360),'mean':(0.485, 0.456, 0.406),'std' :(0.229, 0.224, 0.225),'numpy':False, 'RGB_convert_first':True}):
  17. #输入参数
  18. # im0s---原始图像列表
  19. # model---检测模型,segmodel---分割模型(如若没有用到,则为None)
  20. #输出:两个元素(列表,字符)构成的元组,[im0s[0],im0,det_xywh,iframe],strout
  21. # [im0s[0],im0,det_xywh,iframe]中,
  22. # im0s[0]--原始图像,im0--AI处理后的图像,iframe--帧号/暂时不需用到。
  23. # det_xywh--检测结果,是一个列表。
  24. # 其中每一个元素表示一个目标构成如:[float(cls_c), xc,yc,w,h, float(conf_c)]
  25. # #cls_c--类别,如0,1,2,3; xc,yc,w,h--中心点坐标及宽;conf_c--得分, 取值范围在0-1之间
  26. # #strout---统计AI处理个环节的时间
  27. # Letterbox
  28. half,device,conf_thres,iou_thres,allowedList = objectPar['half'],objectPar['device'],objectPar['conf_thres'],objectPar['iou_thres'],objectPar['allowedList']
  29. slopeIndex, trtFlag_det,trtFlag_seg,segRegionCnt = objectPar['slopeIndex'],objectPar['trtFlag_det'],objectPar['trtFlag_seg'],objectPar['segRegionCnt']
  30. time0=time.time()
  31. if trtFlag_det:
  32. img, padInfos = img_pad(im0s[0], size=(640,640,3)) ;img = [img]
  33. else:
  34. img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s];padInfos=None
  35. # Stack
  36. img = np.stack(img, 0)
  37. # Convert
  38. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  39. img = np.ascontiguousarray(img)
  40. img = torch.from_numpy(img).to(device)
  41. img = img.half() if half else img.float() # uint8 to fp16/32
  42. time01=time.time()
  43. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  44. if segmodel:
  45. if trtFlag_seg:
  46. seg_pred,segstr = segtrtEval(segmodel,im0s[0],par=segPar)
  47. else:
  48. seg_pred,segstr = segmodel.eval(im0s[0] )
  49. segFlag=True
  50. else:
  51. seg_pred = None;segFlag=False;segstr='Not implemented'
  52. time1=time.time()
  53. if trtFlag_det:
  54. pred = yolov5Trtforward(model,img)
  55. else:
  56. pred = model(img,augment=False)[0]
  57. time2=time.time()
  58. datas = [[''], img, im0s, None,pred,seg_pred,10]
  59. ObjectPar={ 'object_config':allowedList, 'slopeIndex':slopeIndex ,'segmodel':segFlag,'segRegionCnt':segRegionCnt }
  60. p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,ObjectPar=ObjectPar,font=font,padInfos=padInfos)
  61. time_info = 'letterbox:%.1f, seg:%.1f , infer:%.1f,%s, seginfo:%s'%( (time01-time0)*1000, (time1-time01)*1000 ,(time2-time1)*1000,timeOut , segstr )
  62. return p_result,time_info
  63. def AI_process_v2(im0s,model,segmodel,names,label_arraylist,rainbows,half=True,device=' cuda:0',conf_thres=0.25, iou_thres=0.45,allowedList=[0,1,2,3], font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3} ):
  64. #输入参数
  65. # im0s---原始图像列表
  66. # model---检测模型,segmodel---分割模型(如若没有用到,则为None)
  67. #输出:两个元素(列表,字符)构成的元组,[im0s[0],im0,det_xywh,iframe],strout
  68. # [im0s[0],im0,det_xywh,iframe]中,
  69. # im0s[0]--原始图像,im0--AI处理后的图像,iframe--帧号/暂时不需用到。
  70. # det_xywh--检测结果,是一个列表。
  71. # 其中每一个元素表示一个目标构成如:[float(cls_c), xc,yc,w,h, float(conf_c)]
  72. # #cls_c--类别,如0,1,2,3; xc,yc,w,h--中心点坐标及宽;conf_c--得分, 取值范围在0-1之间
  73. # #strout---统计AI处理个环节的时间
  74. # Letterbox
  75. time0=time.time()
  76. #img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]
  77. img, padInfos = img_pad(im0s[0], size=(640,640,3)) ;img = [img]
  78. # Stack
  79. img = np.stack(img, 0)
  80. # Convert
  81. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  82. img = np.ascontiguousarray(img)
  83. img = torch.from_numpy(img).to(device)
  84. img = img.half() if half else img.float() # uint8 to fp16/32
  85. time01=time.time()
  86. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  87. if segmodel:
  88. seg_pred,segstr = segmodel.eval(im0s[0] )
  89. segFlag=True
  90. else:
  91. seg_pred = None;segFlag=False
  92. time1=time.time()
  93. pred = model(img,augment=False)
  94. time2=time.time()
  95. datas = [[''], img, im0s, None,pred,seg_pred,10]
  96. p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,object_config=allowedList,segmodel=segFlag,font=font,padInfos=padInfos)
  97. time_info = 'letterbox:%.1f, seg:%.1f , infer:%.1f,%s, seginfo:%s'%( (time01-time0)*1000, (time1-time01)*1000 ,(time2-time1)*1000,timeOut , segstr )
  98. return p_result,time_info
  99. def AI_process_forest(im0s,model,segmodel,names,label_arraylist,rainbows,half=True,device=' cuda:0',conf_thres=0.25, iou_thres=0.45,allowedList=[0,1,2,3], font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3} ,trtFlag_det=False):
  100. #输入参数
  101. # im0s---原始图像列表
  102. # model---检测模型,segmodel---分割模型(如若没有用到,则为None)
  103. #输出:两个元素(列表,字符)构成的元组,[im0s[0],im0,det_xywh,iframe],strout
  104. # [im0s[0],im0,det_xywh,iframe]中,
  105. # im0s[0]--原始图像,im0--AI处理后的图像,iframe--帧号/暂时不需用到。
  106. # det_xywh--检测结果,是一个列表。
  107. # 其中每一个元素表示一个目标构成如:[float(cls_c), xc,yc,w,h, float(conf_c)]
  108. # #cls_c--类别,如0,1,2,3; xc,yc,w,h--中心点坐标及宽;conf_c--得分, 取值范围在0-1之间
  109. # #strout---统计AI处理个环节的时间
  110. # Letterbox
  111. time0=time.time()
  112. if trtFlag_det:
  113. img, padInfos = img_pad(im0s[0], size=(640,640,3)) ;img = [img]
  114. else:
  115. img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s];padInfos=None
  116. #img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]
  117. # Stack
  118. img = np.stack(img, 0)
  119. # Convert
  120. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  121. img = np.ascontiguousarray(img)
  122. img = torch.from_numpy(img).to(device)
  123. img = img.half() if half else img.float() # uint8 to fp16/32
  124. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  125. if segmodel:
  126. seg_pred,segstr = segmodel.eval(im0s[0] )
  127. segFlag=True
  128. else:
  129. seg_pred = None;segFlag=False
  130. time1=time.time()
  131. pred = yolov5Trtforward(model,img) if trtFlag_det else model(img,augment=False)[0]
  132. time2=time.time()
  133. datas = [[''], img, im0s, None,pred,seg_pred,10]
  134. ObjectPar={ 'object_config':allowedList, 'slopeIndex':[] ,'segmodel':segFlag,'segRegionCnt':0 }
  135. p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,ObjectPar=ObjectPar,font=font,padInfos=padInfos)
  136. #p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,object_config=allowedList,segmodel=segFlag,font=font,padInfos=padInfos)
  137. time_info = 'letterbox:%.1f, infer:%.1f, '%( (time1-time0)*1000,(time2-time1)*1000 )
  138. return p_result,time_info+timeOut
  139. def main():
  140. ##预先设置的参数
  141. device_='1' ##选定模型,可选 cpu,'0','1'
  142. ##以下参数目前不可改
  143. Detweights = "weights/yolov5/class5/best_5classes.pt"
  144. seg_nclass = 2
  145. Segweights = "weights/BiSeNet/checkpoint.pth"
  146. conf_thres,iou_thres,classes= 0.25,0.45,5
  147. labelnames = "weights/yolov5/class5/labelnames.json"
  148. rainbows = [ [0,0,255],[0,255,0],[255,0,0],[255,0,255],[255,255,0],[255,129,0],[255,0,127],[127,255,0],[0,255,127],[0,127,255],[127,0,255],[255,127,255],[255,255,127],[127,255,255],[0,255,255],[255,127,255],[127,255,255], [0,127,0],[0,0,127],[0,255,255]]
  149. allowedList=[0,1,2,3]
  150. ##加载模型,准备好显示字符
  151. device = select_device(device_)
  152. names=get_labelnames(labelnames)
  153. label_arraylist = get_label_arrays(names,rainbows,outfontsize=40,fontpath="conf/platech.ttf")
  154. half = device.type != 'cpu' # half precision only supported on CUDA
  155. model = attempt_load(Detweights, map_location=device) # load FP32 model
  156. if half: model.half()
  157. segmodel = SegModel(nclass=seg_nclass,weights=Segweights,device=device)
  158. ##图像测试
  159. #url='images/examples/20220624_响水河_12300_1621.jpg'
  160. impth = 'images/examples/'
  161. outpth = 'images/results/'
  162. folders = os.listdir(impth)
  163. for i in range(len(folders)):
  164. imgpath = os.path.join(impth, folders[i])
  165. im0s=[cv2.imread(imgpath)]
  166. time00 = time.time()
  167. p_result,timeOut = AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,half,device,conf_thres, iou_thres,allowedList,fontSize=1.0)
  168. time11 = time.time()
  169. image_array = p_result[1]
  170. cv2.imwrite( os.path.join( outpth,folders[i] ) ,image_array )
  171. print('----process:%s'%(folders[i]), (time.time() - time11) * 1000)
  172. if __name__=="__main__":
  173. main()