You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

527 lines
22KB

  1. import torch
  2. import numpy as np
  3. import cv2
  4. import time
  5. import os
  6. import sys
  7. sys.path.extend(['../AIlib2/obbUtils'])
  8. import matplotlib.pyplot as plt
  9. import func_utils
  10. import time
  11. import torchvision.transforms as transforms
  12. from obbmodels import ctrbox_net
  13. import decoder
  14. import tensorrt as trt
  15. import onnx
  16. import onnxruntime as ort
  17. sys.path.extend(['../AIlib2/utils'])
  18. #sys.path.extend(['../AIlib2/utils'])
  19. from plots import draw_painting_joint
  20. from copy import deepcopy
  21. from scipy import interpolate
  22. def obbTohbb(obb):
  23. obbarray=np.array(obb)
  24. x0=np.min(obbarray[:,0])
  25. x1=np.max(obbarray[:,0])
  26. y0=np.min(obbarray[:,1])
  27. y1=np.max(obbarray[:,1])
  28. return [x0,y0,x1,y1]
  29. def trt_version():
  30. return trt.__version__
  31. def torch_device_from_trt(device):
  32. if device == trt.TensorLocation.DEVICE:
  33. return torch.device("cuda")
  34. elif device == trt.TensorLocation.HOST:
  35. return torch.device("cpu")
  36. else:
  37. return TypeError("%s is not supported by torch" % device)
  38. def torch_dtype_from_trt(dtype):
  39. if dtype == trt.int8:
  40. return torch.int8
  41. elif trt_version() >= '7.0' and dtype == trt.bool:
  42. return torch.bool
  43. elif dtype == trt.int32:
  44. return torch.int32
  45. elif dtype == trt.float16:
  46. return torch.float16
  47. elif dtype == trt.float32:
  48. return torch.float32
  49. else:
  50. raise TypeError("%s is not supported by torch" % dtype)
  51. def segTrtForward(engine,inputs,contextFlag=False):
  52. if not contextFlag: context = engine.create_execution_context()
  53. else: context=contextFlag
  54. #with engine.create_execution_context() as context:
  55. #input_names=['images'];output_names=['output']
  56. namess=[ engine.get_binding_name(index) for index in range(engine.num_bindings) ]
  57. input_names = [namess[0]];output_names=namess[1:]
  58. batch_size = inputs[0].shape[0]
  59. bindings = [None] * (len(input_names) + len(output_names))
  60. # 创建输出tensor,并分配内存
  61. outputs = [None] * len(output_names)
  62. for i, output_name in enumerate(output_names):
  63. idx = engine.get_binding_index(output_name)#通过binding_name找到对应的input_id
  64. dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))#找到对应的数据类型
  65. shape = (batch_size,) + tuple(engine.get_binding_shape(idx))#找到对应的形状大小
  66. device = torch_device_from_trt(engine.get_location(idx))
  67. output = torch.empty(size=shape, dtype=dtype, device=device)
  68. #print('&'*10,'batch_size:',batch_size , 'device:',device,'idx:',idx,'shape:',shape,'dtype:',dtype,' device:',output.get_device())
  69. outputs[i] = output
  70. #print('###line65:',output_name,i,idx,dtype,shape)
  71. bindings[idx] = output.data_ptr()#绑定输出数据指针
  72. for i, input_name in enumerate(input_names):
  73. idx =engine.get_binding_index(input_name)
  74. bindings[idx] = inputs[0].contiguous().data_ptr()#应当为inputs[i],对应3个输入。但由于我们使用的是单张图片,所以将3个输入全设置为相同的图片。
  75. #print('#'*10,'input_names:,', input_name,'idx:',idx, inputs[0].dtype,', inputs[0] device:',inputs[0].get_device())
  76. context.execute_v2(bindings) # 执行推理
  77. if len(outputs) == 1:
  78. outputs = outputs[0]
  79. return outputs[0]
  80. else:
  81. return outputs
  82. def apply_mask(image, mask, alpha=0.5):
  83. """Apply the given mask to the image.
  84. """
  85. color = np.random.rand(3)
  86. for c in range(3):
  87. image[:, :, c] = np.where(mask == 1,
  88. image[:, :, c] *
  89. (1 - alpha) + alpha * color[c] * 255,
  90. image[:, :, c])
  91. return image
  92. if not os.path.exists('output'):
  93. os.mkdir('output')
  94. saveDir = 'output'
  95. def get_ms(t2,t1):
  96. return (t2-t1)*1000.0
  97. def draw_painting_joint_2(box,img,label_array,score=0.5,color=None,font={ 'line_thickness':None,'boxLine_thickness':None, 'fontSize':None},socre_location="leftTop"):
  98. ###先把中文类别字体赋值到img中
  99. lh, lw, lc = label_array.shape
  100. imh, imw, imc = img.shape
  101. if socre_location=='leftTop':
  102. x0 , y1 = box[0][0],box[0][1]
  103. elif socre_location=='leftBottom':
  104. x0,y1=box[3][0],box[3][1]
  105. else:
  106. print('plot.py line217 ,label_location:%s not implemented '%( socre_location ))
  107. sys.exit(0)
  108. x1 , y0 = x0 + lw , y1 - lh
  109. if y0<0:y0=0;y1=y0+lh
  110. if y1>imh: y1=imh;y0=y1-lh
  111. if x0<0:x0=0;x1=x0+lw
  112. if x1>imw:x1=imw;x0=x1-lw
  113. img[y0:y1,x0:x1,:] = label_array
  114. pts_cls=[(x0,y0),(x1,y1) ]
  115. #把四边形的框画上
  116. box_tl= font['boxLine_thickness'] or round(0.002 * (imh + imw) / 2) + 1
  117. cv2.polylines(img, [box], True,color , box_tl)
  118. ####把英文字符score画到类别旁边
  119. tl = font['line_thickness'] or round(0.002*(imh+imw)/2)+1#line/font thickness
  120. label = ' %.2f'%(score)
  121. tf = max(tl , 1) # font thickness
  122. fontScale = font['fontSize'] or tl * 0.33
  123. t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]
  124. #if socre_location=='leftTop':
  125. p1,p2= (pts_cls[1][0], pts_cls[0][1]),(pts_cls[1][0]+t_size[0],pts_cls[1][1])
  126. cv2.rectangle(img, p1 , p2, color, -1, cv2.LINE_AA)
  127. p3 = pts_cls[1][0],pts_cls[1][1]-(lh-t_size[1])//2
  128. cv2.putText(img, label,p3, 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  129. return img
  130. def OBB_infer(model,ori_image,par):
  131. '''
  132. 输出:[img_origin,ori_image, out_box,9999],infos
  133. img_origin---原图
  134. ori_image---画框图
  135. out_box---检测目标框
  136. ---格式如下[ [ [ (x0,y0),(x1,y1),(x2,y2),(x3,y3) ],score, cls ], [ [ (x0,y0),(x1,y1),(x2,y2),(x3,y3) ],score ,cls ],........ ],etc
  137. ---[ [ [(1159, 297), [922, 615], [817, 591], [1054, 272]], 0.865605354309082,14],
  138. [[(1330, 0), [1289, 58], [1228, 50], [1270, 0]], 0.3928087651729584,14] #2023.08.03,修改输出格式
  139. ]
  140. 9999---无意义,备用
  141. '''
  142. t1 = time.time()
  143. #ori_image = cv2.imread(impth+folders[i])
  144. t2 = time.time()
  145. img= cv2.resize(ori_image, (par['model_size']))
  146. img_origin = ori_image.copy()
  147. t3 = time.time()
  148. transf2 = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=par['mean'], std=par['std'])])
  149. img_tensor = transf2(img)
  150. img_tensor1=img_tensor.unsqueeze(0) #转成了需要的tensor格式,中心归一化及颜色通道匹配上
  151. t4=time.time()
  152. #print('###line170: resize:%.1f ToTensor-Normal-Destd:%.1f '%(get_ms(t3,t2),get_ms(t4,t3) ), img_origin.shape,img_tensor1.size() )
  153. #img_tensor1= img_tensor1.to( par['device']) #布置到cuda上
  154. img_tensor1 = img_tensor1.cuda()
  155. t5 =time.time()
  156. img_tensor1 = img_tensor1.half() if par['half'] else img_tensor1
  157. if par['saveType']=='trt':
  158. preds= segTrtForward(model,[img_tensor1])
  159. preds=[x[0] for x in preds ]
  160. pr_decs={}
  161. heads=list(par['heads'].keys())
  162. pr_decs={ heads[i]: preds[i] for i in range(len(heads)) }
  163. elif par['saveType']=='pth':
  164. with torch.no_grad(): # no Back propagation
  165. pr_decs = model(img_tensor1) # 前向传播一部分
  166. elif par['saveType']=='onnx':
  167. img=img_tensor1.cpu().numpy().astype(np.float32)
  168. preds = model['sess'].run(None, {model['input_name']: img})
  169. pr_decs={}
  170. heads=list(par['heads'].keys())
  171. pr_decs={ heads[i]: torch.from_numpy(preds[i]) for i in range(len(heads)) }
  172. t6 = time.time()
  173. category=par['labelnames']
  174. #torch.cuda.synchronize(par['device']) # 时间异步变同步
  175. decoded_pts = []
  176. decoded_scores = []
  177. predictions = par['decoder'].ctdet_decode(pr_decs) # 解码
  178. t6_1=time.time()
  179. pts0, scores0 = func_utils.decode_prediction(predictions, category,par['model_size'], par['down_ratio'],ori_image) # 改3
  180. decoded_pts.append(pts0)
  181. decoded_scores.append(scores0)
  182. t7 = time.time()
  183. # nms
  184. results = {cat: [] for cat in category}
  185. # '''这里啊
  186. for cat in category:
  187. if cat == 'background':
  188. continue
  189. pts_cat = []
  190. scores_cat = []
  191. for pts0, scores0 in zip(decoded_pts, decoded_scores):
  192. pts_cat.extend(pts0[cat])
  193. scores_cat.extend(scores0[cat])
  194. pts_cat = np.asarray(pts_cat, np.float32)
  195. scores_cat = np.asarray(scores_cat, np.float32)
  196. if pts_cat.shape[0]:
  197. nms_results = func_utils.non_maximum_suppression(pts_cat, scores_cat)
  198. results[cat].extend(nms_results)
  199. t8 = time.time()
  200. height, width, _ = ori_image.shape
  201. # nms
  202. out_box=[]
  203. for cat in category:
  204. if cat == 'background':
  205. continue
  206. result = results[cat]
  207. for pred in result:
  208. score = pred[-1]
  209. cls = category.index(cat)
  210. boxF=[ max(int(x),0) for x in pred[0:8]]
  211. #box_out=[ cls,[ ( boxF[0], boxF[1]),([boxF[2], boxF[3]]), ([boxF[4], boxF[5]]), ([boxF[6], boxF[7]]) ],score]
  212. box_out=[ [ ( boxF[0], boxF[1]),([boxF[2], boxF[3]]), ([boxF[4], boxF[5]]), ([boxF[6], boxF[7]]) ],score,cls]
  213. '''
  214. if par['drawBox']:
  215. tl = np.asarray([pred[0], pred[1]], np.float32)
  216. tr = np.asarray([pred[2], pred[3]], np.float32)
  217. br = np.asarray([pred[4], pred[5]], np.float32)
  218. bl = np.asarray([pred[6], pred[7]], np.float32)
  219. box = np.asarray([tl, tr, br, bl], np.int32)
  220. bgColor=par['rainbows'][cls%len( par['rainbows'])]
  221. label_array =par['label_array'][cls]
  222. font=par['digitWordFont']
  223. label_location=font['label_location']
  224. ori_image=draw_painting_joint(box,ori_image,label_array,score=score,color=bgColor,font=font,socre_location=label_location)
  225. '''
  226. out_box.append(box_out)
  227. t9 = time.time()
  228. t10 = time.time()
  229. infos=' preProcess:%.1f ToGPU:%.1f infer:%.1f decoder:%.1f, corr_change:%.1f nms:%.1f postProcess:%.1f, total process:%.1f '%( get_ms(t4,t2), get_ms(t5,t4),get_ms(t6,t5),get_ms(t6_1,t6),get_ms(t7,t6_1),get_ms(t8,t7) ,get_ms(t9,t8) ,get_ms(t9,t2) )
  230. #'preProcess:%.1f ToGPU:%.1f infer:%.1f decoder:%.1f, corr_change:%.1f nms:%.1f postProcess:%.1f, total process:%.1f '%
  231. #( get_ms(t4,t2), get_ms(t5,t4),get_ms(t6,t5),get_ms(t6_1,t6),get_ms(t7,t6_1), get_ms(t8,t7) ,get_ms(t9,t8) , get_ms(t9,t2) )
  232. if len(out_box) > 0:
  233. ret_4pts = np.array([ x[0] for x in out_box ] )
  234. ret_4pts = rectangle_quadrangle_batch (ret_4pts)
  235. cnt = len(out_box )
  236. for ii in range(cnt):
  237. out_box[ii][0] = ret_4pts[ii]
  238. return [img_origin,ori_image, out_box,9999],infos
  239. def draw_obb(preds,ori_image,par):
  240. for pred in preds:
  241. box = np.asarray(pred[0][0:4],np.int32)
  242. cls = int(pred[2]);score = pred[1]
  243. bgColor=par['rainbows'][cls%len( par['rainbows'])]
  244. label_array =par['label_array'][cls]
  245. font=par['digitWordFont']
  246. label_location=font['label_location']
  247. #print('###line285:',box,cls,score)
  248. ori_image=draw_painting_joint(box,ori_image,label_array,score=score,color=bgColor,font=font,socre_location=label_location)
  249. #cv2.imwrite( 'test.jpg',ori_image )
  250. return ori_image
  251. def OBB_tracker(sort_tracker,hbbs,obbs,iframe):
  252. #sort_tracker--跟踪器
  253. #hbbs--目标的水平框[x0,y0,x1,y1]
  254. #obbs--目标的倾斜框box = np.asarray([tl, tr, br, bl], np.int32)
  255. #返回值:sort_tracker,跟踪器
  256. dets_to_sort = np.empty((0,7), dtype=np.float32)
  257. # NOTE: We send in detected object class too
  258. for x1,y1,x2,y2,conf, detclass in hbbs:
  259. #print('#######line342:',x1,y1,x2,y2,img.shape,[x1, y1, x2, y2, conf, detclass,iframe])
  260. dets_to_sort = np.vstack((dets_to_sort,
  261. np.array([x1, y1, x2, y2, conf, detclass,iframe],dtype=np.float32) ))
  262. # Run SORT
  263. tracked_dets = deepcopy(sort_tracker.update(dets_to_sort,obbs) )
  264. return tracked_dets
  265. def rectangle_quadrangle(vectors):
  266. ##输入的是四个点偏离中心点的向量,(M,4,2)
  267. ##输出:vectors--修正后的向量(M,4,2)
  268. # wh_thetas--矩形的向量 (M,1,3)[w,h,theta]
  269. distans = np.sqrt(np.sum(vectors**2,axis=2))#(M,4)
  270. mean_dis = np.mean( distans,axis=1 ).reshape(-1,1) #(M,1)
  271. mean_dis = np.tile(mean_dis,(1,4) ) #(M,4)
  272. scale_factors = mean_dis/distans #(M,4)
  273. scale_factors = np.expand_dims(scale_factors, axis=2 ) #(M,4,1)
  274. scale_factors = np.tile(scale_factors, (1,1,2) ) #M(M,4,2)
  275. vectors = vectors*scale_factors
  276. vectors = vectors.astype(np.int32)
  277. cnt = vectors.shape[0]
  278. boxes = [ cv2.minAreaRect( vectors[i] ) for i in range(cnt) ]
  279. wh_thetas = [[x[1][0],x[1][1],x[2] ] for x in boxes]#(M,3),[w,h,theta]
  280. wh_thetas = np.array(wh_thetas)##(M,3)
  281. return vectors,wh_thetas
  282. def adjust_pts_orders(vectors):
  283. #输入一系列(M,4,2)点
  284. #输入原定框顺序的(M,4,2)
  285. #前后两个四边形框一次判定,调整下一个四边形框内四个点的顺序,保证与上一个一致。
  286. cnt = vectors.shape[0]
  287. if cnt<=1: return vectors
  288. else:
  289. out=[];out.append(vectors[0])
  290. for i in range(1,cnt):
  291. pts1 = out[-1]
  292. pts2 = vectors[i]
  293. diss,min_dis,min_index,pts2_adjust = pts_setDistance(pts1,pts2)
  294. #if min_index!=0: print(min_index,pts1,pts2 )
  295. out.append(pts2_adjust)
  296. out = np.array(out)
  297. #if out[4,0,0]==53 and out[4,0,1]==10:
  298. #print('#line339:',out.shape ,' ','in ', vectors.reshape(-1,8) , ' out :',out.reshape(-1,8))
  299. return out
  300. def pts_setDistance(pts1,pts2):
  301. #输入是两个四边形的坐标(4,2),pts1保持不变,pts2逐个调整顺序,找到与pts2最匹配的四个点。
  302. #输出pts2 原始的距离,最匹配点的距离,最匹配的点的序号
  303. pts3=np.vstack((pts2,pts2))
  304. diss =[np.sum((pts1-pts3[i:i+4])**2) for i in range(4)]
  305. min_dis = min(diss)
  306. min_index = diss.index(min_dis)
  307. return diss[0],min_dis,min_index,pts3[min_index:min_index+4]
  308. def obbPointsConvert(obbs):
  309. obbArray = np.array(obbs)#( M,4,2)
  310. #计算中心点
  311. middlePts = np.mean( obbArray,axis=1 )##中心点(M,2)
  312. middlePts = np.expand_dims(middlePts,axis=1)#(M,1,2)
  313. #将中心点扩展成(M,4,2)
  314. vectors = np.tile(middlePts,(1,4,1))#(M,4,2)
  315. #计算偏移向量
  316. vectors = obbArray - vectors #(M,4,2)
  317. ##校正偏移向量
  318. vectors,wh_thetas=rectangle_quadrangle(vectors) #vectors--(M,4,2)
  319. ##校正每一个框内四个点的顺序
  320. vectors = adjust_pts_orders(vectors) # (M,4,2)
  321. #将中心点附在偏移向量后面
  322. vectors = np.concatenate( (vectors,middlePts),axis=1 )#(M,5,2),
  323. #将数据拉平
  324. vectors = vectors.reshape(-1,10)#(M,10)
  325. return vectors
  326. def rectangle_quadrangle_batch(obbs):
  327. ##输入出四边形的四个点(M,4,2)
  328. ##输出是矩形话后的4个点(M,4,2)
  329. obbArray = np.array(obbs)#( M,4,2)
  330. #计算中心点
  331. middlePts = np.mean( obbArray,axis=1 )##中心点(M,2)
  332. middlePts = np.expand_dims(middlePts,axis=1)#(M,1,2)
  333. #将中心点扩展成(M,4,2)
  334. middlePts = np.tile(middlePts,(1,4,1))#(M,4,2)
  335. #vectors = np.tile(middlePts,(1,4,1))#(M,4,2)
  336. #计算偏移向量
  337. vectors = obbArray - middlePts #(M,4,2)
  338. ##校正偏移向量
  339. vectors,wh_thetas=rectangle_quadrangle(vectors) #vectors--(M,4,2)
  340. vectors = vectors + middlePts
  341. return vectors
  342. def obbPointsConvert_reverse(vectors):
  343. vectors = np.array(vectors)#(M,10)
  344. _vectors = vectors[:,:8] #(M,8)
  345. middlePts = vectors[:,8:10] #(M,2)
  346. middlePts = np.tile( middlePts,(1,4) ) #(M,8)
  347. _vectors += middlePts #(M,8)
  348. return _vectors
  349. def OBB_tracker_batch(imgarray_list,iframe_list,modelPar,obbModelPar,sort_tracker,trackPar,segPar=None):
  350. '''
  351. 输入:
  352. imgarray_list--图像列表
  353. iframe_list -- 帧号列表
  354. modelPar--模型参数,字典,modelPar={'det_Model':,'seg_Model':}
  355. obbModelpar--字典,存放检测相关参数,'half', 'device', 'conf_thres', 'iou_thres','trtFlag_det'
  356. sort_tracker--对象,初始化的跟踪对象。为了保持一致,即使是单帧也要有。
  357. trackPar--跟踪参数,关键字包括:det_cnt,windowsize
  358. segPar--None,分割模型相关参数。如果用不到,则为None
  359. 输入:[imgarray_list,track_det_result,detResults ] , timeInfos
  360. # timeInfos---时间信息
  361. # imgarray_list--图像列表
  362. # track_det_result--numpy 格式(M,14)--( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 , 11, 12, 13 )
  363. # (x0,y0,x1,y1,x2,y2,x3,y3,xc,yc,conf, detclass,iframe, trackId)
  364. # detResults---给DSP的结果.每一帧是一个list,内部每一个框时一个list,格式为[ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)],score,cls ] 2023.08.03,修改输出格式
  365. '''
  366. det_cnt,windowsize = trackPar['det_cnt'] ,trackPar['windowsize']
  367. trackers_dic={}
  368. index_list = list(range( 0, len(iframe_list) ,det_cnt ));
  369. if len(index_list)>1 and index_list[-1]!= iframe_list[-1]:
  370. index_list.append( len(iframe_list) - 1 )
  371. #print('###line349:',index_list ,iframe_list)
  372. if len(imgarray_list)==1: #如果是单帧图片,则不用跟踪
  373. ori_image_list,infos = OBB_infer(modelPar['obbmodel'],imgarray_list[0],obbModelPar)
  374. #print('##'*20,'line405:',np.array(ori_image_list[2]),ret_4pts )
  375. return ori_image_list,infos
  376. else:
  377. timeInfos_track=''
  378. t1=time.time()
  379. for iframe_index, index_frame in enumerate(index_list):
  380. ori_image_list,infos = OBB_infer(modelPar['obbmodel'],imgarray_list[index_frame],obbModelPar)
  381. obbs = [x[0] for x in ori_image_list[2] ];hbbs = []
  382. for i in range(len(ori_image_list[2])):
  383. hbb=obbTohbb( ori_image_list[2][i][0] );
  384. box=[ *hbb, ori_image_list[2][i][1],ori_image_list[2][i][2]]
  385. hbbs.append(box)
  386. tracked_dets = OBB_tracker(sort_tracker,hbbs,obbs,iframe_list[index_frame] )
  387. tracks =sort_tracker.getTrackers()
  388. tt=[tracker.id for tracker in tracks]
  389. for tracker in tracks:
  390. trackers_dic[tracker.id]=deepcopy(tracker)
  391. t2=time.time()
  392. track_det_result = np.empty((0,14))
  393. ###( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 , 11, 12, 13 )
  394. ###(x0,y0,x1,y1,x2,y2,x3,y3,xc,yc,conf, detclass,iframe, trackId)
  395. trackIdIndex=13;frameIndex=12
  396. #print('###line372:',list(trackers_dic.keys()))
  397. for trackId in trackers_dic.keys():
  398. tracker = trackers_dic[trackId]
  399. obb_history = np.array(tracker.obb_history)
  400. hbb_history = np.array(tracker.bbox_history)
  401. #print('#'*20,obb_history.shape )
  402. if len(obb_history)<2:
  403. #print('#'*20, trackId, ' trace Cnt:',len(obb_history))
  404. continue
  405. #原来格式 np.asarray([tl, tr, br, bl], np.int32)--->中心点到tl, tr, br, bl的向量
  406. #print('###line381: 插值转换前 obb_history:',obb_history.shape, ' trackId:',trackId, ' \n' ,obb_history.reshape(-1,8) )
  407. obb_history = obbPointsConvert(obb_history) #(M,10)
  408. #print('###line381: 插值前 obb_history:',obb_history.shape , ' hbb_history[:,4:7]:',hbb_history[:,4:7].shape, ' trackId:',trackId,'\n',obb_history)
  409. arrays_box = np.concatenate( (obb_history,hbb_history[:,4:7]),axis=1)
  410. arrays_box = arrays_box.transpose();frames=hbb_history[:,6]
  411. #frame_min--表示该批次图片的起始帧,如该批次是[1,100],则frame_min=1,[101,200]--frame_min=101
  412. #frames[0]--表示该目标出现的起始帧,如[1,11,21,31,41],则frames[0]=1,frames[0]可能会在frame_min之前出现,即一个横跨了多个批次。
  413. ##如果要最小化插值范围,则取内区间[frame_min,则frame_max ]和[frames[0],frames[-1] ]的交集
  414. #inter_frame_min = int(max(frame_min, frames[0])); inter_frame_max = int(min( frame_max, frames[-1] )) ##
  415. ##如果要求得到完整的目标轨迹,则插值区间要以目标出现的起始点为准
  416. inter_frame_min=int(frames[0]);inter_frame_max=int(frames[-1])
  417. new_frames= np.linspace(inter_frame_min,inter_frame_max,inter_frame_max-inter_frame_min+1 )
  418. #print('###line389:',trackId, inter_frame_min,inter_frame_max ,frames)
  419. #print(' ##line396: 插值前:' ,arrays_box)
  420. f_linear = interpolate.interp1d(frames,arrays_box); interpolation_x0s = (f_linear(new_frames)).transpose()
  421. move_cnt_use =(len(interpolation_x0s)+1)//2*2-1 if len(interpolation_x0s)<windowsize else windowsize
  422. ###将坐标tl, tr, br, bl的向量--->[tl, tr, br, bl]
  423. interpolation_x0s[:,0:8] = obbPointsConvert_reverse(interpolation_x0s[:,0:10] )
  424. #print('##line403: 插值转换后: ',interpolation_x0s.shape, inter_frame_min,inter_frame_max,frames, '\n',interpolation_x0s )
  425. #for im in range(10):
  426. # interpolation_x0s[:,im] = moving_average_wang(interpolation_x0s[:,im],move_cnt_use )
  427. cnt = inter_frame_max-inter_frame_min+1; trackIds = np.zeros((cnt,1)) + trackId
  428. interpolation_x0s = np.hstack( (interpolation_x0s, trackIds ) )
  429. track_det_result = np.vstack(( track_det_result, interpolation_x0s) )
  430. detResults=[]
  431. for iiframe in iframe_list:
  432. boxes_oneFrame = track_det_result[ track_det_result[:,frameIndex]==iiframe ]
  433. ###( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 , 11, 12, 13 )
  434. ###(x0,y0,x1,y1,x2,y2,x3,y3,xc,yc,conf, detclass,iframe, trackId)
  435. res = [ [ [(b[0],b[1]),(b[2],b[3]),(b[4],b[5]),(b[6],b[7])],b[10],b[11],b[12],b[13] ]
  436. for b in boxes_oneFrame]
  437. detResults.append( res )
  438. t3 = time.time()
  439. timeInfos='%d frames,detect and track:%.1f ,interpolation:%.1f '%( len(index_list), get_ms(t2,t1),get_ms(t3,t2) )
  440. retResults=[imgarray_list,track_det_result,detResults ]
  441. return retResults, timeInfos