import torch
import numpy as np
import cv2
import time
import os
import sys
sys.path.extend(['../AIlib2/obbUtils'])
import matplotlib.pyplot as plt
import func_utils
import torchvision.transforms as transforms
from obbmodels import ctrbox_net
import decoder
import tensorrt as trt
import onnx
import onnxruntime as ort
sys.path.extend(['../AIlib2/utils'])
from plots import draw_painting_joint
from copy import deepcopy
from scipy import interpolate
def obbTohbb(obb):
    # Convert an oriented box (four corner points) to a horizontal box [x0, y0, x1, y1].
    obbarray = np.array(obb)
    x0 = np.min(obbarray[:, 0])
    x1 = np.max(obbarray[:, 0])
    y0 = np.min(obbarray[:, 1])
    y1 = np.max(obbarray[:, 1])
    return [x0, y0, x1, y1]
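# A minimal usage sketch of obbTohbb (nothing assumed beyond numpy): the four
# oriented-box corners collapse to their axis-aligned extent.
def _demo_obbTohbb():
    obb = [(10, 20), (50, 25), (45, 60), (8, 55)]  # (x, y) corners of a tilted box
    print(obbTohbb(obb))  # -> [8, 20, 50, 60]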
def trt_version():
    return trt.__version__

def torch_device_from_trt(device):
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    elif device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    else:
        raise TypeError("%s is not supported by torch" % device)
def torch_dtype_from_trt(dtype):
    if dtype == trt.int8:
        return torch.int8
    elif trt_version() >= '7.0' and dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)
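# Quick sanity checks for the two mapping helpers above; a minimal sketch that
# only relies on the tensorrt enums imported at the top of this file.
def _demo_dtype_mapping():
    assert torch_dtype_from_trt(trt.float32) is torch.float32
    assert torch_dtype_from_trt(trt.int8) is torch.int8
    assert torch_device_from_trt(trt.TensorLocation.HOST) == torch.device("cpu")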
def segTrtForward(engine, inputs, contextFlag=False):
    if not contextFlag:
        context = engine.create_execution_context()
    else:
        context = contextFlag
    namess = [engine.get_binding_name(index) for index in range(engine.num_bindings)]
    input_names = [namess[0]]
    output_names = namess[1:]
    batch_size = inputs[0].shape[0]
    bindings = [None] * (len(input_names) + len(output_names))
    # Create the output tensors and allocate their memory.
    outputs = [None] * len(output_names)
    for i, output_name in enumerate(output_names):
        idx = engine.get_binding_index(output_name)  # binding id looked up by binding name
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))   # matching torch dtype
        shape = (batch_size,) + tuple(engine.get_binding_shape(idx))  # matching shape
        device = torch_device_from_trt(engine.get_location(idx))
        output = torch.empty(size=shape, dtype=dtype, device=device)
        outputs[i] = output
        bindings[idx] = output.data_ptr()  # bind the output data pointer
    for i, input_name in enumerate(input_names):
        idx = engine.get_binding_index(input_name)
        # Should be inputs[i] when there are several inputs; since we feed a single
        # image, every input binding points at the same picture.
        bindings[idx] = inputs[0].contiguous().data_ptr()
    context.execute_v2(bindings)  # run inference
    if len(outputs) == 1:
        outputs = outputs[0]
        return outputs[0]
    else:
        return outputs
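# A minimal sketch of driving segTrtForward: deserialize an engine and run one
# dummy batch. The engine path and the 608x608 input size are assumptions --
# substitute the sizes your engine was built with.
def _demo_segTrtForward():
    logger = trt.Logger(trt.Logger.ERROR)
    with open('obb_model.engine', 'rb') as f, trt.Runtime(logger) as runtime:  # hypothetical path
        engine = runtime.deserialize_cuda_engine(f.read())
    x = torch.zeros((1, 3, 608, 608), dtype=torch.float32, device='cuda')  # dummy input batch
    preds = segTrtForward(engine, [x])  # one tensor per output head
    print([tuple(p.shape) for p in preds] if isinstance(preds, list) else preds.shape)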
def apply_mask(image, mask, alpha=0.5):
    """Apply the given mask to the image."""
    color = np.random.rand(3)
    for c in range(3):
        image[:, :, c] = np.where(mask == 1,
                                  image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
                                  image[:, :, c])
    return image
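# Usage sketch for apply_mask: blend a random colour over the masked region of
# a blank image (alpha=0.5 by default).
def _demo_apply_mask():
    image = np.zeros((64, 64, 3), dtype=np.float32)
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1  # mark a square region
    blended = apply_mask(image, mask)
    print(blended[32, 32], blended[0, 0])  # coloured inside the mask, black outside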
if not os.path.exists('output'):
    os.mkdir('output')
saveDir = 'output'
def get_ms(t2, t1):
    # Elapsed milliseconds between two time.time() stamps.
    return (t2 - t1) * 1000.0
def draw_painting_joint_2(box, img, label_array, score=0.5, color=None,
                          font={'line_thickness': None, 'boxLine_thickness': None, 'fontSize': None},
                          socre_location="leftTop"):
    # First paste the pre-rendered class-name bitmap (label_array) onto img.
    lh, lw, lc = label_array.shape
    imh, imw, imc = img.shape
    if socre_location == 'leftTop':
        x0, y1 = box[0][0], box[0][1]
    elif socre_location == 'leftBottom':
        x0, y1 = box[3][0], box[3][1]
    else:
        print('draw_painting_joint_2: score location %s not implemented' % socre_location)
        sys.exit(0)
    x1, y0 = x0 + lw, y1 - lh
    # Clamp the label patch so it stays inside the image.
    if y0 < 0: y0 = 0; y1 = y0 + lh
    if y1 > imh: y1 = imh; y0 = y1 - lh
    if x0 < 0: x0 = 0; x1 = x0 + lw
    if x1 > imw: x1 = imw; x0 = x1 - lw
    img[y0:y1, x0:x1, :] = label_array
    pts_cls = [(x0, y0), (x1, y1)]
    # Draw the quadrilateral box.
    box_tl = font['boxLine_thickness'] or round(0.002 * (imh + imw) / 2) + 1
    cv2.polylines(img, [box], True, color, box_tl)
    # Draw the score string next to the class label.
    tl = font['line_thickness'] or round(0.002 * (imh + imw) / 2) + 1  # line/font thickness
    label = ' %.2f' % score
    tf = max(tl, 1)  # font thickness
    fontScale = font['fontSize'] or tl * 0.33
    t_size = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=tf)[0]
    p1, p2 = (pts_cls[1][0], pts_cls[0][1]), (pts_cls[1][0] + t_size[0], pts_cls[1][1])
    cv2.rectangle(img, p1, p2, color, -1, cv2.LINE_AA)
    p3 = pts_cls[1][0], pts_cls[1][1] - (lh - t_size[1]) // 2
    cv2.putText(img, label, p3, 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    return img
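# A minimal sketch of draw_painting_joint_2. label_array is normally a
# pre-rendered class-name bitmap; a plain white patch stands in for it here.
def _demo_draw_painting_joint_2():
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    label_array = np.full((20, 60, 3), 255, dtype=np.uint8)  # stand-in label bitmap
    box = np.array([[100, 100], [200, 110], [195, 180], [95, 170]], np.int32)  # tl, tr, br, bl
    out = draw_painting_joint_2(box, img, label_array, score=0.87, color=(0, 255, 0))
    cv2.imwrite(os.path.join(saveDir, 'demo_label.jpg'), out)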
def OBB_infer(model, ori_image, par):
    '''
    Output: [img_origin, ori_image, out_box, 9999], infos
        img_origin -- the original image
        ori_image  -- the image with boxes drawn on it
        out_box    -- detected boxes, formatted as
                      [ [ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)], score, cls ], ... ], e.g.
                      [ [ [(1159, 297), [922, 615], [817, 591], [1054, 272]], 0.865605354309082, 14],
                        [ [(1330, 0), [1289, 58], [1228, 50], [1270, 0]], 0.3928087651729584, 14] ]  # 2023.08.03: output format changed
        9999       -- meaningless, reserved
    '''
    t1 = time.time()
    t2 = time.time()
    img = cv2.resize(ori_image, par['model_size'])
    img_origin = ori_image.copy()
    t3 = time.time()
    transf2 = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=par['mean'], std=par['std'])])
    img_tensor = transf2(img)
    img_tensor1 = img_tensor.unsqueeze(0)  # batched tensor, normalised, channels matched to the model
    t4 = time.time()
    img_tensor1 = img_tensor1.cuda()  # move to the GPU
    t5 = time.time()
    img_tensor1 = img_tensor1.half() if par['half'] else img_tensor1
    if par['saveType'] == 'trt':
        preds = segTrtForward(model, [img_tensor1])
        preds = [x[0] for x in preds]
        heads = list(par['heads'].keys())
        pr_decs = {heads[i]: preds[i] for i in range(len(heads))}
    elif par['saveType'] == 'pth':
        with torch.no_grad():  # no back propagation
            pr_decs = model(img_tensor1)  # forward pass
    elif par['saveType'] == 'onnx':
        img = img_tensor1.cpu().numpy().astype(np.float32)
        preds = model['sess'].run(None, {model['input_name']: img})
        heads = list(par['heads'].keys())
        pr_decs = {heads[i]: torch.from_numpy(preds[i]) for i in range(len(heads))}
    t6 = time.time()
    category = par['labelnames']
    decoded_pts = []
    decoded_scores = []
    predictions = par['decoder'].ctdet_decode(pr_decs)  # decode the head outputs
    t6_1 = time.time()
    pts0, scores0 = func_utils.decode_prediction(predictions, category, par['model_size'], par['down_ratio'], ori_image)
    decoded_pts.append(pts0)
    decoded_scores.append(scores0)
    t7 = time.time()
    # nms
    results = {cat: [] for cat in category}
    for cat in category:
        if cat == 'background':
            continue
        pts_cat = []
        scores_cat = []
        for pts0, scores0 in zip(decoded_pts, decoded_scores):
            pts_cat.extend(pts0[cat])
            scores_cat.extend(scores0[cat])
        pts_cat = np.asarray(pts_cat, np.float32)
        scores_cat = np.asarray(scores_cat, np.float32)
        if pts_cat.shape[0]:
            nms_results = func_utils.non_maximum_suppression(pts_cat, scores_cat)
            results[cat].extend(nms_results)
    t8 = time.time()
    height, width, _ = ori_image.shape
    out_box = []
    for cat in category:
        if cat == 'background':
            continue
        result = results[cat]
        for pred in result:
            score = pred[-1]
            cls = category.index(cat)
            boxF = [max(int(x), 0) for x in pred[0:8]]
            box_out = [[(boxF[0], boxF[1]), ([boxF[2], boxF[3]]), ([boxF[4], boxF[5]]), ([boxF[6], boxF[7]])], score, cls]
            # Drawing is handled separately by draw_obb() below.
            out_box.append(box_out)
    t9 = time.time()
    infos = ' preProcess:%.1f ToGPU:%.1f infer:%.1f decoder:%.1f, corr_change:%.1f nms:%.1f postProcess:%.1f, total process:%.1f ' % (
        get_ms(t4, t2), get_ms(t5, t4), get_ms(t6, t5), get_ms(t6_1, t6), get_ms(t7, t6_1), get_ms(t8, t7), get_ms(t9, t8), get_ms(t9, t2))
    if len(out_box) > 0:
        # Snap each detected quadrilateral to a rectangle before returning.
        ret_4pts = np.array([x[0] for x in out_box])
        ret_4pts = rectangle_quadrangle_batch(ret_4pts)
        cnt = len(out_box)
        for ii in range(cnt):
            out_box[ii][0] = ret_4pts[ii]
    return [img_origin, ori_image, out_box, 9999], infos
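# A sketch of a plausible `par` dict for OBB_infer with a .pth model. Every
# value here (input size, normalisation, head layout, label names) is a
# placeholder -- take the real ones from this project's configuration.
def _demo_OBB_infer(model, decoder_obj):
    par = {
        'model_size': (608, 608),  # network input (w, h)
        'mean': (0.408, 0.447, 0.470),
        'std': (0.289, 0.274, 0.278),
        'half': False,
        'saveType': 'pth',         # 'pth' | 'trt' | 'onnx'
        'heads': {'hm': 1, 'wh': 10, 'reg': 2, 'cls_theta': 1},
        'labelnames': ['background', 'ship'],
        'decoder': decoder_obj,    # object exposing ctdet_decode()
        'down_ratio': 4,
    }
    ori_image = cv2.imread('test.jpg')  # hypothetical test image
    (img_origin, drawn, out_box, _), infos = OBB_infer(model, ori_image, par)
    print(infos, 'boxes:', len(out_box))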
def draw_obb(preds, ori_image, par):
    # Draw every predicted oriented box (with class label and score) onto ori_image.
    for pred in preds:
        box = np.asarray(pred[0][0:4], np.int32)
        cls = int(pred[2]); score = pred[1]
        bgColor = par['rainbows'][cls % len(par['rainbows'])]
        label_array = par['label_array'][cls]
        font = par['digitWordFont']
        label_location = font['label_location']
        ori_image = draw_painting_joint(box, ori_image, label_array, score=score, color=bgColor, font=font, socre_location=label_location)
    return ori_image
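# draw_obb chains directly onto OBB_infer output; `par` must additionally
# carry the drawing keys used above ('rainbows', 'label_array', and a
# 'digitWordFont' dict with its 'label_location'). A sketch:
def _demo_draw_obb(model, ori_image, par):
    (img_origin, _, out_box, _), infos = OBB_infer(model, ori_image, par)
    drawn = draw_obb(out_box, img_origin.copy(), par)
    cv2.imwrite(os.path.join(saveDir, 'demo_obb.jpg'), drawn)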
def OBB_tracker(sort_tracker, hbbs, obbs, iframe):
    # sort_tracker -- the tracker object
    # hbbs -- horizontal boxes per target: [x0, y0, x1, y1, conf, cls]
    # obbs -- oriented boxes per target: np.asarray([tl, tr, br, bl], np.int32)
    # Returns the tracked detections for this frame.
    dets_to_sort = np.empty((0, 7), dtype=np.float32)
    # NOTE: We send in detected object class too
    for x1, y1, x2, y2, conf, detclass in hbbs:
        dets_to_sort = np.vstack((dets_to_sort,
                                  np.array([x1, y1, x2, y2, conf, detclass, iframe], dtype=np.float32)))
    # Run SORT
    tracked_dets = deepcopy(sort_tracker.update(dets_to_sort, obbs))
    return tracked_dets
def rectangle_quadrangle(vectors):
    # Input: offsets of the four corners from the centre point, (M,4,2).
    # Output: vectors   -- corrected offsets (M,4,2)
    #         wh_thetas -- rectangle parameters (M,3) as [w, h, theta]
    distans = np.sqrt(np.sum(vectors**2, axis=2))          # (M,4)
    mean_dis = np.mean(distans, axis=1).reshape(-1, 1)     # (M,1)
    mean_dis = np.tile(mean_dis, (1, 4))                   # (M,4)
    scale_factors = mean_dis / distans                     # (M,4)
    scale_factors = np.expand_dims(scale_factors, axis=2)  # (M,4,1)
    scale_factors = np.tile(scale_factors, (1, 1, 2))      # (M,4,2)
    vectors = vectors * scale_factors
    vectors = vectors.astype(np.int32)
    cnt = vectors.shape[0]
    boxes = [cv2.minAreaRect(vectors[i]) for i in range(cnt)]
    wh_thetas = [[x[1][0], x[1][1], x[2]] for x in boxes]  # (M,3), [w, h, theta]
    wh_thetas = np.array(wh_thetas)
    return vectors, wh_thetas
def adjust_pts_orders(vectors):
    # Input: a sequence of quadrilaterals (M,4,2).
    # Output: (M,4,2) with a consistent corner order.
    # Compare each quadrilateral with the previous one and reorder its four
    # points so the ordering stays consistent from frame to frame.
    cnt = vectors.shape[0]
    if cnt <= 1:
        return vectors
    else:
        out = []
        out.append(vectors[0])
        for i in range(1, cnt):
            pts1 = out[-1]
            pts2 = vectors[i]
            diss, min_dis, min_index, pts2_adjust = pts_setDistance(pts1, pts2)
            out.append(pts2_adjust)
        out = np.array(out)
        return out
def pts_setDistance(pts1, pts2):
    # Input: two quadrilaterals (4,2). pts1 stays fixed; pts2 is cyclically
    # shifted to find the corner order that best matches pts1.
    # Output: the distance at the original order, the best-match distance,
    # the best-match shift, and pts2 in the best-matching order.
    pts3 = np.vstack((pts2, pts2))
    diss = [np.sum((pts1 - pts3[i:i+4])**2) for i in range(4)]
    min_dis = min(diss)
    min_index = diss.index(min_dis)
    return diss[0], min_dis, min_index, pts3[min_index:min_index+4]
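# Worked example for pts_setDistance: the second square lists the same corners
# rotated by one position; the helper finds the shift that re-aligns them.
def _demo_pts_setDistance():
    pts1 = np.array([[0, 0], [10, 0], [10, 10], [0, 10]])
    pts2 = np.array([[0, 10], [0, 0], [10, 0], [10, 10]])  # pts1 rolled by one
    diss0, min_dis, min_index, pts2_adjust = pts_setDistance(pts1, pts2)
    print(min_index, min_dis)           # -> 1 0 : shifting by one makes them identical
    print((pts2_adjust == pts1).all())  # -> True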
def obbPointsConvert(obbs):
    obbArray = np.array(obbs)  # (M,4,2)
    # Compute the centre points.
    middlePts = np.mean(obbArray, axis=1)          # (M,2)
    middlePts = np.expand_dims(middlePts, axis=1)  # (M,1,2)
    # Broadcast the centres to (M,4,2).
    vectors = np.tile(middlePts, (1, 4, 1))        # (M,4,2)
    # Offsets of each corner from its centre.
    vectors = obbArray - vectors                   # (M,4,2)
    # Rectify the offsets.
    vectors, wh_thetas = rectangle_quadrangle(vectors)  # vectors -- (M,4,2)
    # Make the corner order consistent across boxes.
    vectors = adjust_pts_orders(vectors)           # (M,4,2)
    # Append the centre point after the offsets.
    vectors = np.concatenate((vectors, middlePts), axis=1)  # (M,5,2)
    # Flatten.
    vectors = vectors.reshape(-1, 10)              # (M,10)
    return vectors
def rectangle_quadrangle_batch(obbs):
    # Input: the four corner points of each quadrilateral, (M,4,2).
    # Output: the four corner points after rectification, (M,4,2).
    obbArray = np.array(obbs)  # (M,4,2)
    # Centre points.
    middlePts = np.mean(obbArray, axis=1)          # (M,2)
    middlePts = np.expand_dims(middlePts, axis=1)  # (M,1,2)
    middlePts = np.tile(middlePts, (1, 4, 1))      # (M,4,2)
    # Offsets from the centre.
    vectors = obbArray - middlePts                 # (M,4,2)
    # Rectify the offsets, then shift back to absolute coordinates.
    vectors, wh_thetas = rectangle_quadrangle(vectors)
    vectors = vectors + middlePts
    return vectors
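# Sketch of rectangle_quadrangle_batch on one slightly skewed quadrilateral:
# the corner offsets are rescaled to a common radius, so the output is the
# (M,4,2) corners of a near-rectangle around the same centre.
def _demo_rectangle_quadrangle_batch():
    quad = np.array([[[0, 0], [40, 2], [42, 22], [2, 20]]], dtype=np.float32)  # (1,4,2)
    fixed = rectangle_quadrangle_batch(quad)
    print(fixed.shape)  # -> (1, 4, 2)
    print(fixed[0])     # rectified corners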
def obbPointsConvert_reverse(vectors):
    # Inverse of obbPointsConvert: turn (M,10) offset+centre rows back into
    # (M,8) absolute corner coordinates.
    vectors = np.array(vectors)             # (M,10)
    _vectors = vectors[:, :8]               # (M,8)
    middlePts = vectors[:, 8:10]            # (M,2)
    middlePts = np.tile(middlePts, (1, 4))  # (M,8)
    _vectors += middlePts                   # (M,8)
    return _vectors
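# Round trip: obbPointsConvert packs each box as 8 centre offsets plus the
# centre (M,10); obbPointsConvert_reverse restores absolute corners (M,8).
def _demo_obbPointsConvert_roundtrip():
    obbs = [[(0, 0), (40, 0), (40, 20), (0, 20)]]  # one axis-aligned box
    vecs = obbPointsConvert(obbs)                  # (1,10)
    corners = obbPointsConvert_reverse(vecs)       # (1,8)
    print(corners.reshape(-1, 4, 2))               # ~ the original corners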
def OBB_tracker_batch(imgarray_list, iframe_list, modelPar, obbModelPar, sort_tracker, trackPar, segPar=None):
    '''
    Input:
        imgarray_list -- list of images
        iframe_list   -- list of frame numbers
        modelPar      -- model parameters, a dict: modelPar={'det_Model': ..., 'seg_Model': ...}
        obbModelPar   -- dict of detection parameters: 'half', 'device', 'conf_thres', 'iou_thres', 'trtFlag_det'
        sort_tracker  -- an initialised tracker object. For consistency it is required even for a single frame.
        trackPar      -- tracking parameters, keys: det_cnt, windowsize
        segPar        -- segmentation-model parameters; None if unused
    Output: [imgarray_list, track_det_result, detResults], timeInfos
        timeInfos        -- timing information
        imgarray_list    -- list of images
        track_det_result -- numpy array (M,14):
                            ( 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10,   11,       12,     13     )
                            (x0, y0, x1, y1, x2, y2, x3, y3, xc, yc, conf, detclass, iframe, trackId)
        detResults       -- results for the DSP: one list per frame, each box a list
                            [ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)], conf, cls, iframe, trackId ]  # 2023.08.03: output format changed
    '''
    det_cnt, windowsize = trackPar['det_cnt'], trackPar['windowsize']
    trackers_dic = {}
    index_list = list(range(0, len(iframe_list), det_cnt))
    # Make sure the last frame is also a detection frame.
    if len(index_list) > 1 and index_list[-1] != len(iframe_list) - 1:
        index_list.append(len(iframe_list) - 1)
    if len(imgarray_list) == 1:  # a single image: no tracking needed
        ori_image_list, infos = OBB_infer(modelPar['obbmodel'], imgarray_list[0], obbModelPar)
        return ori_image_list, infos
    else:
        timeInfos_track = ''
        t1 = time.time()
        for iframe_index, index_frame in enumerate(index_list):
            ori_image_list, infos = OBB_infer(modelPar['obbmodel'], imgarray_list[index_frame], obbModelPar)
            obbs = [x[0] for x in ori_image_list[2]]
            hbbs = []
            for i in range(len(ori_image_list[2])):
                hbb = obbTohbb(ori_image_list[2][i][0])
                box = [*hbb, ori_image_list[2][i][1], ori_image_list[2][i][2]]
                hbbs.append(box)
            tracked_dets = OBB_tracker(sort_tracker, hbbs, obbs, iframe_list[index_frame])
            tracks = sort_tracker.getTrackers()
            for tracker in tracks:
                trackers_dic[tracker.id] = deepcopy(tracker)
        t2 = time.time()
        track_det_result = np.empty((0, 14))
        ###( 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10,   11,       12,     13     )
        ###(x0, y0, x1, y1, x2, y2, x3, y3, xc, yc, conf, detclass, iframe, trackId)
        trackIdIndex = 13
        frameIndex = 12
        for trackId in trackers_dic.keys():
            tracker = trackers_dic[trackId]
            obb_history = np.array(tracker.obb_history)
            hbb_history = np.array(tracker.bbox_history)
            if len(obb_history) < 2:
                continue
            # Original format np.asarray([tl, tr, br, bl], np.int32) ---> vectors from the centre to tl, tr, br, bl.
            obb_history = obbPointsConvert(obb_history)  # (M,10)
            arrays_box = np.concatenate((obb_history, hbb_history[:, 4:7]), axis=1)
            arrays_box = arrays_box.transpose()
            frames = hbb_history[:, 6]
            # frame_min is the first frame of the current batch: for a batch [1,100], frame_min=1; for [101,200], frame_min=101.
            # frames[0] is the frame where this target first appeared, e.g. for [1,11,21,31,41], frames[0]=1;
            # frames[0] may lie before frame_min when one track spans several batches.
            # To minimise the interpolation range, intersect [frame_min, frame_max] with [frames[0], frames[-1]]:
            #   inter_frame_min = int(max(frame_min, frames[0])); inter_frame_max = int(min(frame_max, frames[-1]))
            # To recover the full trajectory, interpolate from the frame where the target first appeared:
            inter_frame_min = int(frames[0])
            inter_frame_max = int(frames[-1])
            new_frames = np.linspace(inter_frame_min, inter_frame_max, inter_frame_max - inter_frame_min + 1)
            f_linear = interpolate.interp1d(frames, arrays_box)
            interpolation_x0s = (f_linear(new_frames)).transpose()
            move_cnt_use = (len(interpolation_x0s) + 1) // 2 * 2 - 1 if len(interpolation_x0s) < windowsize else windowsize
            # Turn the centre-offset representation back into absolute corners [tl, tr, br, bl].
            interpolation_x0s[:, 0:8] = obbPointsConvert_reverse(interpolation_x0s[:, 0:10])
            # Optional smoothing (disabled):
            # for im in range(10):
            #     interpolation_x0s[:, im] = moving_average_wang(interpolation_x0s[:, im], move_cnt_use)
            cnt = inter_frame_max - inter_frame_min + 1
            trackIds = np.zeros((cnt, 1)) + trackId
            interpolation_x0s = np.hstack((interpolation_x0s, trackIds))
            track_det_result = np.vstack((track_det_result, interpolation_x0s))
        detResults = []
        for iiframe in iframe_list:
            boxes_oneFrame = track_det_result[track_det_result[:, frameIndex] == iiframe]
            res = [[[(b[0], b[1]), (b[2], b[3]), (b[4], b[5]), (b[6], b[7])], b[10], b[11], b[12], b[13]]
                   for b in boxes_oneFrame]
            detResults.append(res)
        t3 = time.time()
        timeInfos = '%d frames,detect and track:%.1f ,interpolation:%.1f ' % (len(index_list), get_ms(t2, t1), get_ms(t3, t2))
        retResults = [imgarray_list, track_det_result, detResults]
        return retResults, timeInfos
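# End-to-end sketch for OBB_tracker_batch on a short clip. The frame paths and
# trackPar values are placeholders; sort_tracker must be an already-initialised
# tracker exposing update() and getTrackers(), as used above.
def _demo_OBB_tracker_batch(modelPar, obbModelPar, sort_tracker):
    frames = [cv2.imread('frame_%03d.jpg' % i) for i in range(10)]  # hypothetical clip
    iframes = list(range(1, 11))
    trackPar = {'det_cnt': 5, 'windowsize': 5}  # detect every 5th frame, 5-frame smoothing window
    (imgs, track_det_result, detResults), timeInfos = OBB_tracker_batch(
        frames, iframes, modelPar, obbModelPar, sort_tracker, trackPar, segPar=None)
    print(timeInfos, track_det_result.shape)  # (M,14) rows as documented above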