import cv2, os, time, json
from models.experimental import attempt_load
from segutils.segmodel import SegModel, get_largest_contours
from segutils.trtUtils import segtrtEval, yolov5Trtforward, OcrTrtForward
from segutils.trafficUtils import tracfficAccidentMixFunction
from utils.torch_utils import select_device
from utilsK.queRiver import get_labelnames, get_label_arrays, post_process_, img_pad, draw_painting_joint, detectDraw, getDetections, getDetectionsFromPreds
from trackUtils.sort import moving_average_wang
from utils.datasets import letterbox
import numpy as np
import torch
import math
from PIL import Image
import torch.nn.functional as F
from copy import deepcopy
from scipy import interpolate
import glob
def get_images_videos(impth, imageFixs=['.jpg','.JPG','.PNG','.png'], videoFixs=['.MP4','.mp4','.avi']):
    imgpaths = []    ### collect all images under the path
    videopaths = []  ### collect all videos under the path
    if os.path.isdir(impth):
        for postfix in imageFixs:
            imgpaths.extend(glob.glob('%s/*%s'%(impth, postfix)))
        for postfix in videoFixs:
            videopaths.extend(glob.glob('%s/*%s'%(impth, postfix)))
    else:
        postfix = os.path.splitext(impth)[-1]
        if postfix in imageFixs: imgpaths = [impth]
        if postfix in videoFixs: videopaths = [impth]
    print('%s: test images:%d, test videos:%d'%(impth, len(imgpaths), len(videopaths)))
    return imgpaths, videopaths
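# A minimal usage sketch (the folder path is an assumption; main() below uses
# the repo's own 'images/examples/' folder):
#   imgpaths, videopaths = get_images_videos('images/examples/')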
def xywh2xyxy(box, iW=None, iH=None):
    # convert a normalized (xc,yc,w,h) box to corner form, clipped to [0,1],
    # then optionally scale to pixel coordinates with the image width/height
    xc, yc, w, h = box[0:4]
    x0 = max(0, xc - w/2.0)
    x1 = min(1, xc + w/2.0)
    y0 = max(0, yc - h/2.0)
    y1 = min(1, yc + h/2.0)
    if iW: x0, x1 = x0*iW, x1*iW
    if iH: y0, y1 = y0*iH, y1*iH
    return [x0, y0, x1, y1]
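# Worked example with hypothetical values: a normalized box (xc=0.5, yc=0.5,
# w=0.2, h=0.4) on a 1920x1080 image maps to pixel corners
#   xywh2xyxy([0.5, 0.5, 0.2, 0.4], iW=1920, iH=1080) -> [768.0, 324.0, 1152.0, 756.0]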
def get_ms(t2, t1):
    return (t2 - t1)*1000.0
def get_postProcess_para(parfile):
    with open(parfile) as fp:
        par = json.load(fp)
    assert 'post_process' in par.keys(), 'parfile has no key word: post_process'
    parPost = par['post_process']
    return parPost["conf_thres"], parPost["iou_thres"], parPost["classes"], parPost["rainbows"]
def get_postProcess_para_dic(parfile):
    with open(parfile) as fp:
        par = json.load(fp)
    parPost = par['post_process']
    return parPost
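# A minimal sketch of the parfile layout implied by the two readers above
# (only the keys read here are grounded; the concrete values are assumptions):
# {
#   "post_process": {
#     "conf_thres": 0.25,
#     "iou_thres": 0.45,
#     "classes": 5,
#     "rainbows": [[0, 0, 255], [0, 255, 0]]
#   }
# }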
def score_filter_byClass(pdetections, score_para_2nd):
    # keep a detection only if its score exceeds the per-class threshold in
    # score_para_2nd (keys may be int or str); the default threshold is 0.7
    ret = []
    for det in pdetections:
        score, cls = det[4], det[5]
        if int(cls) in score_para_2nd.keys():
            score_th = score_para_2nd[int(cls)]
        elif str(int(cls)) in score_para_2nd.keys():
            score_th = score_para_2nd[str(int(cls))]
        else:
            score_th = 0.7
        if score > score_th:
            ret.append(det)
    return ret
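# Example with hypothetical thresholds: with score_para_2nd = {0: 0.5, '1': 0.8},
# a detection [xc, yc, w, h, 0.6, 0.0] is kept (0.6 > 0.5) while
# [xc, yc, w, h, 0.6, 1.0] is dropped (0.6 <= 0.8); classes missing from the
# dict fall back to the default threshold 0.7.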
def AI_process(im0s, model, segmodel, names, label_arraylist, rainbows,
               objectPar={'half':True, 'device':'cuda:0', 'conf_thres':0.25, 'iou_thres':0.45, 'allowedList':[0,1,2,3], 'segRegionCnt':1, 'trtFlag_det':False, 'trtFlag_seg':False, 'score_byClass':{x:0.1 for x in range(30)}},
               font={'line_thickness':None, 'fontSize':None, 'boxLine_thickness':None, 'waterLineColor':(0,255,255), 'waterLineWidth':3},
               segPar={'modelSize':(640,360), 'mean':(0.485, 0.456, 0.406), 'std':(0.229, 0.224, 0.225), 'numpy':False, 'RGB_convert_first':True},
               mode='others', postPar=None):
    # Inputs:
    #   im0s---list of original images
    #   model---detection model; segmodel---segmentation model (None if unused)
    # Output: a tuple of two elements (list, string): [im0s[0], im0, det_xywh, iframe], strout
    #   in [im0s[0], im0, det_xywh, iframe]:
    #     im0s[0]--original image; im0--image after AI processing; iframe--frame number (not needed for now)
    #     det_xywh--detection results, a list
    #       each element describes one target, e.g. [xc, yc, w, h, float(conf_c), float(cls_c)]; output format changed 2023.08.03
    #       cls_c--class id, e.g. 0,1,2,3; xc,yc,w,h--center coordinates and size; conf_c--score in [0,1]
    #   strout---timing statistics for each stage of the AI processing
    # Letterbox
    half, device, conf_thres, iou_thres, allowedList = objectPar['half'], objectPar['device'], objectPar['conf_thres'], objectPar['iou_thres'], objectPar['allowedList']
    trtFlag_det, trtFlag_seg, segRegionCnt = objectPar['trtFlag_det'], objectPar['trtFlag_seg'], objectPar['segRegionCnt']
    if 'ovlap_thres_crossCategory' in objectPar.keys(): ovlap_thres = objectPar['ovlap_thres_crossCategory']
    else: ovlap_thres = None
    if 'score_byClass' in objectPar.keys(): score_byClass = objectPar['score_byClass']
    else: score_byClass = None
    time0 = time.time()
    if trtFlag_det:
        img, padInfos = img_pad(im0s[0], size=(640,640,3)); img = [img]
    else:
        img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]; padInfos = None
    # Stack
    img = np.stack(img, 0)
    # Convert
    img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0
    time01 = time.time()
    if segmodel:
        seg_pred, segstr = segmodel.eval(im0s[0])
        segFlag = True
    else:
        seg_pred = None; segFlag = False; segstr = 'Not implemented'
    time1 = time.time()
    if trtFlag_det:
        pred = yolov5Trtforward(model, img)
    else:
        pred = model(img, augment=False)[0]
    time2 = time.time()
    p_result, timeOut = getDetectionsFromPreds(pred, img, im0s[0], conf_thres=conf_thres, iou_thres=iou_thres, ovlap_thres=ovlap_thres, padInfos=padInfos)
    if score_byClass:
        p_result[2] = score_filter_byClass(p_result[2], score_byClass)
    # use .get so the default segPar (which has no 'mixFunction' key) does not raise a KeyError
    if segPar and segPar.get('mixFunction', {}).get('function'):
        mixFunction = segPar['mixFunction']['function']; H, W = im0s[0].shape[0:2]
        parMix = segPar['mixFunction']['pars']
        parMix['imgSize'] = (W, H)
        p_result[2], timeMixPost = mixFunction(p_result[2], seg_pred, pars=parMix)
        p_result.append(seg_pred)
    else:
        timeMixPost = ':0 ms'
    time_info = 'letterbox:%.1f, seg:%.1f, infer:%.1f,%s, seginfo:%s, timeMixPost:%s '%((time01-time0)*1000, (time1-time01)*1000, (time2-time1)*1000, timeOut, segstr.strip(), timeMixPost)
    return p_result, time_info
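# A minimal single-image call, assuming model/segmodel/names/label_arraylist/rainbows
# are prepared as in main() below (the image path is an assumption):
#   im0s = [cv2.imread('images/examples/some.jpg')]
#   p_result, tinfo = AI_process(im0s, model, segmodel, names, label_arraylist, rainbows)
#   dets = p_result[2]   # list of [xc, yc, w, h, conf, cls]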
def default_mix(predlist, par):
    return predlist[0], ''
def AI_process_N(im0s, modelList, postProcess):
    # Inputs:
    #   im0s---list of original images
    #   modelList--all models; each model is a class whose eval() returns that model's inference result
    #   postProcess--dict holding the post-process function and its parameters
    # Outputs:
    #   ret[0]--detection results
    #   ret[1]--timing information
    modelRets = [model.eval(im0s[0]) for model in modelList]
    timeInfos = [x[1] for x in modelRets]
    timeInfos = ''.join(timeInfos)
    # postProcess['function']--the post-process function; its input is the outputs of all models
    mixFunction = postProcess['function']
    predsList = [modelRet[0] for modelRet in modelRets]
    H, W = im0s[0].shape[0:2]
    postProcess['pars']['imgSize'] = (W, H)
    # ret is the mixed, post-processed result
    ret = mixFunction(predsList, postProcess['pars'])
    return ret[0], timeInfos + ret[1]
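# A minimal sketch of the postProcess dict AI_process_N expects; default_mix above
# serves as a stand-in post-processor (real pipelines pass their own function/pars):
#   postProcess = {'name': 'demo', 'function': default_mix, 'pars': {}}
#   dets, tinfo = AI_process_N([frame], modelList, postProcess)  # 'imgSize' is filled in automatically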
def getMaxScoreWords(detRets0):
    # return the index of the detection with the highest score
    maxScore = -1; maxId = 0
    for i, detRet in enumerate(detRets0):
        if detRet[4] > maxScore:
            maxId = i
            maxScore = detRet[4]
    return maxId
def AI_process_C(im0s, modelList, postProcess):
    # Why this function is customized:
    ## the previous pipeline was
    ##   image ---> model1 --> result1; image ---> model2 --> result2; [result1, result2] ---> post-process
    ## this function's pipeline is
    ##   image ---> model1 --> result1; [image, result1] ---> model2 --> result2; [result1, result2] ---> post-process
    ## i.e. the input of model 2 is determined by the output of model 1. For example, if model 2 is an OCR model,
    ## the ship-name regions detected by model 1 are cropped out and fed to model 2.
    ## In the previous pipeline model 2 was a segmentation model whose input was the original image,
    ## independent of model 1's output.
    # Inputs:
    #   im0s---list of original images
    #   modelList--all models; each model is a class whose eval() returns that model's inference result
    #   postProcess--dict holding the post-process function and its parameters
    # Outputs:
    #   ret[0]--detection results
    #   ret[1]--timing information
    t0 = time.time()
    detRets0 = modelList[0].eval(im0s[0])
    #detRets0=[[12, 46, 1127, 1544, 0.2340087890625, 2.0], [1884, 1248, 2992, 1485, 0.64208984375, 1.0]]
    detRets0 = detRets0[0]
    parsIn = postProcess['pars']
    _detRets0_obj = list(filter(lambda x: x[5] in parsIn['objs'], detRets0))
    _detRets0_others = list(filter(lambda x: x[5] not in parsIn['objs'], detRets0))
    _detRets0 = []
    if postProcess['name'] == 'channel2':
        if len(_detRets0_obj) > 0:
            maxId = getMaxScoreWords(_detRets0_obj)
            _detRets0 = _detRets0_obj[maxId:maxId+1]
    else:
        _detRets0 = detRets0
    t1 = time.time()
    imagePatches = [im0s[0][int(x[1]):int(x[3]), int(x[0]):int(x[2])] for x in _detRets0]
    detRets1 = [modelList[1].eval(patch) for patch in imagePatches]
    if postProcess['name'] == 'crackMeasurement':
        detRets1 = [x[0]*255 for x in detRets1]
        t2 = time.time()
        mixFunction = postProcess['function']
        crackInfos = [mixFunction(patchMask, par=parsIn) for patchMask in detRets1]
        rets = [_detRets0[i] + crackInfos[i] for i in range(len(imagePatches))]
        t3 = time.time()
        outInfos = 'total:%.1f (det:%.1f segs x%d:%.1f mixProcess:%.1f) '%((t3-t0)*1000, (t1-t0)*1000, len(detRets1), (t2-t1)*1000, (t3-t2)*1000)
    elif postProcess['name'] == 'channel2':
        H, W = im0s[0].shape[0:2]; parsIn['imgSize'] = (W, H)
        mixFunction = postProcess['function']
        _detRets0_others = mixFunction([_detRets0_others], parsIn)
        ocrInfo = 'no ocr'
        if len(_detRets0_obj) > 0:
            res_real = detRets1[0][0]
            # keep only CJK characters and digits in the OCR string
            res_real = "".join(list(filter(lambda x: (ord(x) > 19968 and ord(x) < 63865) or (ord(x) > 47 and ord(x) < 58), res_real)))
            _detRets0_obj[maxId].append(res_real)
            _detRets0_obj = [_detRets0_obj[maxId]]  ## only output the ship-name box that has an OCR result
            ocrInfo = detRets1[0][1]
        rets = _detRets0_obj + _detRets0_others
        t3 = time.time()
        outInfos = 'total:%.1f, where det:%.1f, ocr:%s'%((t3-t0)*1000, (t1-t0)*1000, ocrInfo)
    return rets, outInfos
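# A minimal sketch of the two-stage call (model objects and the mix function are
# assumptions; modelList[0] detects, modelList[1] runs on crops of its output):
#   postProcess = {'name': 'channel2', 'function': some_mix, 'pars': {'objs': [1]}}
#   rets, tinfo = AI_process_C([frame], [detModel, ocrModel], postProcess)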
def AI_process_forest(im0s, model, segmodel, names, label_arraylist, rainbows, half=True, device='cuda:0', conf_thres=0.25, iou_thres=0.45, allowedList=[0,1,2,3],
                      font={'line_thickness':None, 'fontSize':None, 'boxLine_thickness':None, 'waterLineColor':(0,255,255), 'waterLineWidth':3}, trtFlag_det=False, SecNms=None):
    # Inputs:
    #   im0s---list of original images
    #   model---detection model; segmodel---segmentation model (None if unused)
    # Output: a tuple of two elements (list, string): [im0s[0], im0, det_xywh, iframe], strout
    #   in [im0s[0], im0, det_xywh, iframe]:
    #     im0s[0]--original image; im0--image after AI processing; iframe--frame number (not needed for now)
    #     det_xywh--detection results, a list
    #       each element describes one target, e.g. [xc, yc, w, h, float(conf_c), float(cls_c)]; output format changed 2023.08.03
    #       cls_c--class id, e.g. 0,1,2,3; xc,yc,w,h--center coordinates and size; conf_c--score in [0,1]
    #   strout---timing statistics for each stage of the AI processing
    # Letterbox
    time0 = time.time()
    if trtFlag_det:
        img, padInfos = img_pad(im0s[0], size=(640,640,3)); img = [img]
    else:
        img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]; padInfos = None
    # Stack
    img = np.stack(img, 0)
    # Convert
    img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if segmodel:
        seg_pred, segstr = segmodel.eval(im0s[0])
        segFlag = True
    else:
        seg_pred = None; segFlag = False
    time1 = time.time()
    pred = yolov5Trtforward(model, img) if trtFlag_det else model(img, augment=False)[0]
    time2 = time.time()
    datas = [[''], img, im0s, None, pred, seg_pred, 10]
    ObjectPar = {'object_config':allowedList, 'slopeIndex':[], 'segmodel':segFlag, 'segRegionCnt':0}
    p_result, timeOut = post_process_(datas, conf_thres, iou_thres, names, label_arraylist, rainbows, 10, ObjectPar=ObjectPar, font=font, padInfos=padInfos, ovlap_thres=SecNms)
    time_info = 'letterbox:%.1f, infer:%.1f, '%((time1-time0)*1000, (time2-time1)*1000)
    return p_result, time_info + timeOut
def AI_det_track(im0s_in, modelPar, processPar, sort_tracker, segPar=None):
    im0s, iframe = im0s_in[0], im0s_in[1]
    model = modelPar['det_Model']
    segmodel = modelPar['seg_Model']
    half, device, conf_thres, iou_thres, trtFlag_det = processPar['half'], processPar['device'], processPar['conf_thres'], processPar['iou_thres'], processPar['trtFlag_det']
    if 'score_byClass' in processPar.keys(): score_byClass = processPar['score_byClass']
    else: score_byClass = None
    iou2nd = processPar['iou2nd']
    time0 = time.time()
    if trtFlag_det:
        img, padInfos = img_pad(im0s[0], size=(640,640,3)); img = [img]
    else:
        img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]; padInfos = None
    img = np.stack(img, 0)
    # Convert
    img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    seg_pred = None; segFlag = False
    time1 = time.time()
    pred = yolov5Trtforward(model, img) if trtFlag_det else model(img, augment=False)[0]
    time2 = time.time()
    p_result, timeOut = getDetectionsFromPreds(pred, img, im0s[0], conf_thres=conf_thres, iou_thres=iou_thres, ovlap_thres=iou2nd, padInfos=padInfos)
    if score_byClass:
        p_result[2] = score_filter_byClass(p_result[2], score_byClass)
    if segmodel:
        seg_pred, segstr = segmodel.eval(im0s[0])
        segFlag = True
    else:
        seg_pred = None; segFlag = False; segstr = 'No segmodel'
    if segPar and segPar.get('mixFunction', {}).get('function'):
        mixFunction = segPar['mixFunction']['function']
        H, W = im0s[0].shape[0:2]
        parMix = segPar['mixFunction']['pars']
        parMix['imgSize'] = (W, H)
        p_result[2], timeInfos_post = mixFunction(p_result[2], seg_pred, pars=parMix)
        timeInfos_seg_post = 'segInfer:%s, postMixProcess:%s'%(segstr, timeInfos_post)
    else:
        timeInfos_seg_post = ' '
    time_info = 'letterbox:%.1f, detinfer:%.1f, '%((time1-time0)*1000, (time2-time1)*1000)
    if sort_tracker:
        # here one could control how often the tracker is invoked
        #..................USE TRACK FUNCTION....................
        # pass an empty array to SORT
        dets_to_sort = np.empty((0,7), dtype=np.float32)
        # NOTE: we send in the detected object class too
        for x1, y1, x2, y2, conf, detclass in p_result[2]:
            dets_to_sort = np.vstack((dets_to_sort,
                                      np.array([x1, y1, x2, y2, conf, detclass, iframe], dtype=np.float32)))
        # Run SORT
        tracked_dets = deepcopy(sort_tracker.update(dets_to_sort))
        tracks = sort_tracker.getTrackers()
        p_result.append(tracked_dets)  ### index=4
        p_result.append(tracks)        ### index=5
    return p_result, time_info + timeOut + timeInfos_seg_post
def AI_det_track_batch(imgarray_list, iframe_list, modelPar, processPar, sort_tracker, trackPar, segPar=None):
    '''
    Inputs:
        imgarray_list--list of images
        iframe_list--list of frame numbers
        modelPar--model parameters, a dict: modelPar={'det_Model':, 'seg_Model':}
        processPar--dict of detection parameters: 'half', 'device', 'conf_thres', 'iou_thres', 'trtFlag_det'
        sort_tracker--an initialized tracker object; required even for a single frame, to keep the interface uniform
        trackPar--tracking parameters, with keys det_cnt and windowsize
        segPar--segmentation-model parameters, or None if unused
    Outputs: retResults, timeInfos
        retResults: list
        retResults[0]--imgarray_list
        retResults[1]--all detection results as a numpy array with 8 columns: x1, y1, x2, y2, conf, detclass, iframe, trackId
        retResults[2]--all results as a list; each element is one frame's detection list, each of whose elements is one box
                       [x0, y0, x1, y1, conf, cls, iframe, trackId]; e.g. retResults[2][j][k] is the k-th box of frame j.
                       Output format changed 2023.08.03.
    '''
    det_cnt, windowsize = trackPar['det_cnt'], trackPar['windowsize']
    trackers_dic = {}
    index_list = list(range(0, len(iframe_list), det_cnt))
    if len(index_list) > 1 and index_list[-1] != iframe_list[-1]:
        index_list.append(len(iframe_list) - 1)
    if len(imgarray_list) == 1:  # a single image needs no tracking
        retResults = []
        p_result, timeOut = AI_det_track([[imgarray_list[0]], iframe_list[0]], modelPar, processPar, None, segPar)
        ## the next few lines only keep the output format consistent with the multi-frame case
        detArray = np.array(p_result[2])
        if len(p_result[2]) == 0: res = []
        else:
            cnt = detArray.shape[0]; trackIds = np.zeros((cnt,1)); iframes = np.zeros((cnt,1)) + iframe_list[0]
            detArray = np.hstack((detArray[:,0:4], detArray[:,4:6], iframes, trackIds))  ## 2023.08.03 input format changed
            res = [[b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]] for b in detArray]
        retResults = [imgarray_list, detArray, res]
        return retResults, timeOut
    else:
        t0 = time.time()
        timeInfos_track = ''
        for iframe_index, index_frame in enumerate(index_list):
            p_result, timeOut = AI_det_track([[imgarray_list[index_frame]], iframe_list[index_frame]], modelPar, processPar, sort_tracker, segPar)
            timeInfos_track = '%s:%s'%(timeInfos_track, timeOut)
            for tracker in p_result[5]:
                trackers_dic[tracker.id] = deepcopy(tracker)
        t1 = time.time()
        track_det_result = np.empty((0,8))
        for trackId in trackers_dic.keys():
            tracker = trackers_dic[trackId]
            bbox_history = np.array(tracker.bbox_history)
            if len(bbox_history) < 2: continue
            ### convert (x0,y0,x1,y1) to (xc,yc,w,h)
            xcs_ycs = (bbox_history[:,0:2] + bbox_history[:,2:4])/2
            whs = bbox_history[:,2:4] - bbox_history[:,0:2]
            bbox_history[:,0:2] = xcs_ycs; bbox_history[:,2:4] = whs
            arrays_box = bbox_history[:,0:7].transpose(); frames = bbox_history[:,6]
            # frame_min--first frame of the batch, e.g. 1 for batch [1,100] and 101 for [101,200]
            # frames[0]--first frame on which the target appears, e.g. 1 for [1,11,21,31,41]; frames[0] may lie before
            #            frame_min, i.e. one track can span several batches.
            ## to restrict interpolation to the batch, take the intersection of [frame_min, frame_max] and [frames[0], frames[-1]]:
            # inter_frame_min = int(max(frame_min, frames[0])); inter_frame_max = int(min(frame_max, frames[-1]))
            ## to recover the full trajectory, interpolate from the frame on which the target first appears:
            inter_frame_min = int(frames[0]); inter_frame_max = int(frames[-1])
            new_frames = np.linspace(inter_frame_min, inter_frame_max, inter_frame_max - inter_frame_min + 1)
            f_linear = interpolate.interp1d(frames, arrays_box); interpolation_x0s = (f_linear(new_frames)).transpose()
            move_cnt_use = (len(interpolation_x0s)+1)//2*2 - 1 if len(interpolation_x0s) < windowsize else windowsize
            for im in range(4):
                interpolation_x0s[:,im] = moving_average_wang(interpolation_x0s[:,im], move_cnt_use)
            cnt = inter_frame_max - inter_frame_min + 1; trackIds = np.zeros((cnt,1)) + trackId
            interpolation_x0s = np.hstack((interpolation_x0s, trackIds))
            track_det_result = np.vstack((track_det_result, interpolation_x0s))
        ## convert [xc,yc,w,h] back to [x0,y0,x1,y1]
        x0s = track_det_result[:,0] - track_det_result[:,2]/2; x1s = track_det_result[:,0] + track_det_result[:,2]/2
        y0s = track_det_result[:,1] - track_det_result[:,3]/2; y1s = track_det_result[:,1] + track_det_result[:,3]/2
        track_det_result[:,0] = x0s; track_det_result[:,1] = y0s
        track_det_result[:,2] = x1s; track_det_result[:,3] = y1s
        detResults = []
        for iiframe in iframe_list:
            boxes_oneFrame = track_det_result[track_det_result[:,6] == iiframe]
            res = [[b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]] for b in boxes_oneFrame]
            # [x0, y0, x1, y1, conf, cls, iframe, trackId]
            detResults.append(res)
        retResults = [imgarray_list, track_det_result, detResults]
        t2 = time.time()
        timeInfos = 'detTrack:%.1f TrackPost:%.1f, %s'%(get_ms(t1,t0), get_ms(t2,t1), timeInfos_track)
        return retResults, timeInfos
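# A sketch of consuming the batch result (frames/modelPar/processPar/tracker/trackPar
# are assumed to be prepared by the caller):
#   retResults, tinfo = AI_det_track_batch(frames, list(range(len(frames))), modelPar, processPar, tracker, trackPar)
#   for boxes in retResults[2]:                     # one list per frame
#       for x0, y0, x1, y1, conf, cls, ifrm, tid in boxes:
#           pass  # e.g. draw or log each tracked box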
def AI_det_track_N(im0s_in, modelList, postProcess, sort_tracker):
    im0s, iframe = im0s_in[0], im0s_in[1]
    dets = AI_process_N(im0s, modelList, postProcess)
    p_result = [[], [], dets[0], []]
    if sort_tracker:
        # here one could control how often the tracker is invoked
        #..................USE TRACK FUNCTION....................
        # pass an empty array to SORT
        dets_to_sort = np.empty((0,7), dtype=np.float32)
        # NOTE: we send in the detected object class too
        for x1, y1, x2, y2, conf, detclass in p_result[2]:
            dets_to_sort = np.vstack((dets_to_sort,
                                      np.array([x1, y1, x2, y2, conf, detclass, iframe], dtype=np.float32)))
        # Run SORT
        tracked_dets = deepcopy(sort_tracker.update(dets_to_sort))
        tracks = sort_tracker.getTrackers()
        p_result.append(tracked_dets)  ### index=4
        p_result.append(tracks)        ### index=5
    return p_result, dets[1]
def get_tracker_cls(boxes, scId=4, clsId=5):
    # Normally all boxes on one track chain share a class, but occasional detection
    # errors can put several classes on the same chain. The chain's class is therefore
    # chosen as the class whose boxes have the largest summed confidence.
    # Input boxes--the box_history kept while tracking: [[xc,yc,width,height,score,class,iframe],[...],[...]]
    #   scId=4: index of the score; clsId=5: index of the class
    # Output: the class of the target on this track chain
    ids = list(set(boxes[:,clsId].tolist()))
    scores = [np.sum(boxes[:,scId][boxes[:,clsId] == x]) for x in ids]
    maxScoreId = scores.index(np.max(scores))
    return int(ids[maxScoreId])
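# Worked example with a hypothetical chain: boxes with (score, class) of
# (0.9, 2), (0.3, 1) and (0.4, 1) give summed scores {2: 0.9, 1: 0.7},
# so the whole chain is relabelled as class 2.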
def AI_det_track_batch_N(imgarray_list, iframe_list, modelList, postProcess, sort_tracker, trackPar):
    '''
    Inputs:
        imgarray_list--list of images
        iframe_list--list of frame numbers
        modelList--all models; each model is a class whose eval() returns that model's inference result
        postProcess--dict holding the post-process function and its parameters
        sort_tracker--an initialized tracker object; required even for a single frame, to keep the interface uniform
        trackPar--tracking parameters, with keys det_cnt and windowsize
    Outputs: retResults, timeInfos
        retResults: list
        retResults[0]--imgarray_list
        retResults[1]--all detection results as a numpy array with 8 columns: x1, y1, x2, y2, conf, detclass, iframe, trackId
        retResults[2]--all results as a list; each element is one frame's detection list, each of whose elements is one box
                       [x0, y0, x1, y1, conf, cls, iframe, trackId]; e.g. retResults[2][j][k] is the k-th box of frame j.
                       Output format changed 2023.08.03.
    '''
    det_cnt, windowsize = trackPar['det_cnt'], trackPar['windowsize']
    trackers_dic = {}
    index_list = list(range(0, len(iframe_list), det_cnt))
    if len(index_list) > 1 and index_list[-1] != iframe_list[-1]:
        index_list.append(len(iframe_list) - 1)
    if len(imgarray_list) == 1:  # a single image needs no tracking
        retResults = []
        p_result, timeOut = AI_det_track_N([[imgarray_list[0]], iframe_list[0]], modelList, postProcess, None)
        ## the next few lines only keep the output format consistent with the multi-frame case
        detArray = np.array(p_result[2])
        if len(p_result[2]) == 0: res = []
        else:
            cnt = detArray.shape[0]; trackIds = np.zeros((cnt,1)); iframes = np.zeros((cnt,1)) + iframe_list[0]
            detArray = np.hstack((detArray[:,0:4], detArray[:,4:6], iframes, trackIds))  ## 2023.08.03 input format changed
            res = [[b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]] for b in detArray]
        retResults = [imgarray_list, detArray, res]
        return retResults, timeOut
    else:
        t0 = time.time()
        timeInfos_track = ''
        for iframe_index, index_frame in enumerate(index_list):
            p_result, timeOut = AI_det_track_N([[imgarray_list[index_frame]], iframe_list[index_frame]], modelList, postProcess, sort_tracker)
            timeInfos_track = '%s:%s'%(timeInfos_track, timeOut)
            for tracker in p_result[5]:
                trackers_dic[tracker.id] = deepcopy(tracker)
        t1 = time.time()
        track_det_result = np.empty((0,8))
        for trackId in trackers_dic.keys():
            tracker = trackers_dic[trackId]
            bbox_history = np.array(tracker.bbox_history).copy()
            if len(bbox_history) < 2: continue
            ### convert (x0,y0,x1,y1) to (xc,yc,w,h)
            xcs_ycs = (bbox_history[:,0:2] + bbox_history[:,2:4])/2
            whs = bbox_history[:,2:4] - bbox_history[:,0:2]
            bbox_history[:,0:2] = xcs_ycs; bbox_history[:,2:4] = whs
            # added 2023.11.17: force every box on a track chain to the same class
            chainClsId = get_tracker_cls(bbox_history, scId=4, clsId=5)
            bbox_history[:,5] = chainClsId
            arrays_box = bbox_history[:,0:7].transpose(); frames = bbox_history[:,6]
            # frame_min--first frame of the batch, e.g. 1 for batch [1,100] and 101 for [101,200]
            # frames[0]--first frame on which the target appears, e.g. 1 for [1,11,21,31,41]; frames[0] may lie before
            #            frame_min, i.e. one track can span several batches.
            ## to restrict interpolation to the batch, take the intersection of [frame_min, frame_max] and [frames[0], frames[-1]]:
            # inter_frame_min = int(max(frame_min, frames[0])); inter_frame_max = int(min(frame_max, frames[-1]))
            ## to recover the full trajectory, interpolate from the frame on which the target first appears:
            inter_frame_min = int(frames[0]); inter_frame_max = int(frames[-1])
            new_frames = np.linspace(inter_frame_min, inter_frame_max, inter_frame_max - inter_frame_min + 1)
            f_linear = interpolate.interp1d(frames, arrays_box); interpolation_x0s = (f_linear(new_frames)).transpose()
            move_cnt_use = (len(interpolation_x0s)+1)//2*2 - 1 if len(interpolation_x0s) < windowsize else windowsize
            for im in range(4):
                interpolation_x0s[:,im] = moving_average_wang(interpolation_x0s[:,im], move_cnt_use)
            cnt = inter_frame_max - inter_frame_min + 1; trackIds = np.zeros((cnt,1)) + trackId
            interpolation_x0s = np.hstack((interpolation_x0s, trackIds))
            track_det_result = np.vstack((track_det_result, interpolation_x0s))
        ## convert [xc,yc,w,h] back to [x0,y0,x1,y1]
        x0s = track_det_result[:,0] - track_det_result[:,2]/2; x1s = track_det_result[:,0] + track_det_result[:,2]/2
        y0s = track_det_result[:,1] - track_det_result[:,3]/2; y1s = track_det_result[:,1] + track_det_result[:,3]/2
        track_det_result[:,0] = x0s; track_det_result[:,1] = y0s
        track_det_result[:,2] = x1s; track_det_result[:,3] = y1s
        detResults = []
        for iiframe in iframe_list:
            boxes_oneFrame = track_det_result[track_det_result[:,6] == iiframe]
            res = [[b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]] for b in boxes_oneFrame]
            # [x0, y0, x1, y1, conf, cls, iframe, trackId]
            detResults.append(res)
        retResults = [imgarray_list, track_det_result, detResults]
        t2 = time.time()
        timeInfos = 'detTrack:%.1f TrackPost:%.1f, %s'%(get_ms(t1,t0), get_ms(t2,t1), timeInfos_track)
        return retResults, timeInfos
def ocr_process(pars):
    img_patch, engine, context, converter, AlignCollate_normal, device = pars[0:6]
    time1 = time.time()
    img_tensor = AlignCollate_normal([Image.fromarray(img_patch, 'L')])
    img_input = img_tensor.to('cuda:0')
    time2 = time.time()
    preds, trtstr = OcrTrtForward(engine, [img_input], context)
    time3 = time.time()
    batch_size = preds.size(0)
    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
    ######## filter ignore_char, rebalance
    preds_prob = F.softmax(preds, dim=2)
    preds_prob = preds_prob.cpu().detach().numpy()
    pred_norm = preds_prob.sum(axis=2)
    preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1)
    preds_prob = torch.from_numpy(preds_prob).float().to(device)
    _, preds_index = preds_prob.max(2)
    preds_index = preds_index.view(-1)
    time4 = time.time()
    preds_str = converter.decode_greedy(preds_index.data.cpu().detach().numpy(), preds_size.data)
    time5 = time.time()
    info_str = ('pre-process:%.2f TRTforward:%.2f (%s) postProcess:%.2f decoder:%.2f, Total:%.2f, pred:%s'%(get_ms(time2,time1), get_ms(time3,time2), trtstr, get_ms(time4,time3), get_ms(time5,time4), get_ms(time5,time1), preds_str))
    return preds_str, info_str
def main():
    ## preset parameters
    device_ = '1'  ## device selection: 'cpu', '0', '1'
    ## the parameters below are fixed for now
    Detweights = "weights/yolov5/class5/best_5classes.pt"
    seg_nclass = 2
    Segweights = "weights/BiSeNet/checkpoint.pth"
    conf_thres, iou_thres, classes = 0.25, 0.45, 5
    labelnames = "weights/yolov5/class5/labelnames.json"
    rainbows = [[0,0,255],[0,255,0],[255,0,0],[255,0,255],[255,255,0],[255,129,0],[255,0,127],[127,255,0],[0,255,127],[0,127,255],[127,0,255],[255,127,255],[255,255,127],[127,255,255],[0,255,255],[255,127,255],[127,255,255],[0,127,0],[0,0,127],[0,255,255]]
    allowedList = [0,1,2,3]
    ## load the models and prepare the label arrays for drawing
    device = select_device(device_)
    names = get_labelnames(labelnames)
    label_arraylist = get_label_arrays(names, rainbows, outfontsize=40, fontpath="conf/platech.ttf")
    half = device.type != 'cpu'  # half precision only supported on CUDA
    model = attempt_load(Detweights, map_location=device)  # load FP32 model
    if half: model.half()
    segmodel = SegModel(nclass=seg_nclass, weights=Segweights, device=device)
    ## image test
    #url='images/examples/20220624_响水河_12300_1621.jpg'
    impth = 'images/examples/'
    outpth = 'images/results/'
    folders = os.listdir(impth)
    # pass the detection parameters through objectPar; the previous positional
    # call did not match AI_process's signature
    objectPar = {'half':half, 'device':device, 'conf_thres':conf_thres, 'iou_thres':iou_thres,
                 'allowedList':allowedList, 'segRegionCnt':1, 'trtFlag_det':False, 'trtFlag_seg':False,
                 'score_byClass':None}
    for i in range(len(folders)):
        imgpath = os.path.join(impth, folders[i])
        im0s = [cv2.imread(imgpath)]
        time00 = time.time()
        p_result, timeOut = AI_process(im0s, model, segmodel, names, label_arraylist, rainbows, objectPar=objectPar)
        time11 = time.time()
        image_array = p_result[1]
        cv2.imwrite(os.path.join(outpth, folders[i]), image_array)
        #print('----process:%s'%(folders[i]), (time.time() - time11) * 1000)
if __name__=="__main__":
    main()