You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

queRiver.py 13KB

2 vuotta sitten
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from kafka import KafkaProducer, KafkaConsumer
  2. from kafka.errors import kafka_errors
  3. import traceback
  4. import json, base64,os
  5. import numpy as np
  6. from multiprocessing import Process,Queue
  7. import time,cv2,string,random
  8. import subprocess as sp
  9. import matplotlib.pyplot as plt
  10. from utils.datasets import LoadStreams, LoadImages
  11. from models.experimental import attempt_load
  12. from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
  13. import torch,sys
  14. #from segutils.segmodel import SegModel,get_largest_contours
  15. #sys.path.extend(['../yolov5/segutils'])
  16. from segutils.segWaterBuilding import SegModel,get_largest_contours,illBuildings
  17. #from segutils.core.models.bisenet import BiSeNet
  18. from segutils.core.models.bisenet import BiSeNet_MultiOutput
  19. from utils.plots import plot_one_box,plot_one_box_PIL,draw_painting_joint,get_label_arrays,get_websource
  20. from collections import Counter
  21. #import matplotlib
  22. import matplotlib.pyplot as plt
  23. # get_labelnames,get_label_arrays,post_process_,save_problem_images,time_str
  # Module-wide debug log; gpu_process and post_process append one timing line
  # per frame to it.
  # NOTE(review): opened at import time and never closed explicitly; the name
  # 'debut.txt' may be a typo for 'debug.txt' -- confirm before renaming.
  FP_DEBUG = open('debut.txt', 'w')
  def bsJpgCode(image_ori):
      """Encode an image as JPEG and return the bytes as base64 text.

      Args:
          image_ori: image array accepted by ``cv2.imencode``.

      Returns:
          str: base64 representation of the JPEG-encoded image.
      """
      # cv2.imencode returns (success_flag, buffer); the buffer is the last item.
      jpg_buffer = cv2.imencode('.jpg', image_ori)[-1]
      # Decode the base64 bytes directly instead of slicing the "b'...'" repr,
      # which only works by accident of how bytes objects are printed.
      return base64.b64encode(jpg_buffer).decode('ascii')
  def bsJpgDecode(bsCode):
      """Turn base64 text holding an encoded image back into an image array.

      Args:
          bsCode: base64 string (or bytes) as produced by ``bsJpgCode``.

      Returns:
          The array returned by ``cv2.imdecode`` in ``IMREAD_COLOR`` mode.
      """
      raw_bytes = base64.b64decode(bsCode)
      # View the decoded bytes as a flat uint8 buffer for OpenCV.
      byte_array = np.frombuffer(raw_bytes, np.uint8)
      return cv2.imdecode(byte_array, cv2.IMREAD_COLOR)
  def get_ms(time0, time1):
      """Format the interval ``time1 - time0`` (seconds) as milliseconds.

      Returns:
          str: the elapsed time with two decimals followed by `` ms``,
          e.g. ``'12.50 ms'``.
      """
      elapsed_ms = (time1 - time0) * 1000
      return f'{elapsed_ms:.2f} ms'
  # Colour table with 20 entries; post_process_ picks an entry with
  # ``int(cls) % 20``. Some colours appear more than once -- the table is
  # reproduced as-is so class colours stay unchanged.
  rainbows = [
      (0, 0, 255),
      (0, 255, 0),
      (255, 0, 0),
      (255, 0, 255),
      (255, 255, 0),
      (255, 127, 0),
      (255, 0, 127),
      (127, 255, 0),
      (0, 255, 127),
      (0, 127, 255),
      (127, 0, 255),
      (255, 127, 255),
      (255, 255, 127),
      (127, 255, 255),
      (0, 255, 255),
      (255, 127, 255),
      (127, 255, 255),
      (0, 127, 0),
      (0, 0, 127),
      (0, 255, 255),
  ]
  def get_labelnames(labelnames):
      """Load the list of class label names from a JSON file.

      Args:
          labelnames: path to a JSON file containing a ``"labelnames"`` key.

      Returns:
          The value stored under ``"labelnames"``.

      Raises:
          OSError: if the file cannot be opened.
          KeyError: if the JSON document has no ``"labelnames"`` key.
      """
      # Read as UTF-8 explicitly: label names may be non-ASCII and the
      # platform default encoding is not guaranteed to be UTF-8.
      with open(labelnames, 'r', encoding='utf-8') as fp:
          return json.load(fp)['labelnames']
  def check_stream(stream):
      """Report whether a video source can be opened.

      Args:
          stream: anything accepted by ``cv2.VideoCapture`` (URL, path, index).

      Returns:
          bool: True when the capture opened successfully, otherwise False.
      """
      cap = cv2.VideoCapture(stream)
      try:
          return bool(cap.isOpened())
      finally:
          # Always free the probe handle so repeated checks do not leak
          # connections/decoders.
          cap.release()
  55. #####
  def drawWater(pred, image_array0):
      """Outline the largest water region on the image and build its mask.

      ``pred`` is the model output for the water-only segmentation task.

      Args:
          pred: segmentation map passed to ``cv2.findContours``.
          image_array0: image that the contour is drawn onto (modified in place).

      Returns:
          tuple: ``(image_array0, water)`` where ``water`` has the shape and
          dtype of ``pred``, is 1 inside the selected contour and 0 elsewhere
          (all zeros when no contour is found).
      """
      contours, _hierarchy = cv2.findContours(pred, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

      # Start from an all-zero mask shaped like the prediction.
      water = pred.copy()
      water.fill(0)

      if contours:
          biggest = get_largest_contours(contours)
          outline = contours[biggest][:, 0, :]
          cv2.fillPoly(water, [outline], 1)
          cv2.drawContours(image_array0, contours, biggest, (0, 255, 255), 3)

      return image_array0, water
  def post_process_(datas, conf_thres, iou_thres, names, label_arraylist, rainbows, iframe,
                    object_config=(0, 1, 2, 3, 4)):
      """Post-process one frame: NMS, segmentation overlay, box filtering, drawing.

      Steps: non-max suppression -> draw the segmentation result -> rescale
      boxes to the original image -> keep boxes that overlap the water mask
      -> draw the remaining boxes of the requested classes.

      Args:
          datas: sequence whose first six items are
              ``path, img, im0s, vid_cap, pred, seg_pred``.
          conf_thres: confidence threshold passed to ``non_max_suppression``.
          iou_thres: IoU threshold passed to ``non_max_suppression``.
          names: class names (kept for interface compatibility; not needed
              for drawing, which uses ``label_arraylist``).
          label_arraylist: per-class label images given to
              ``draw_painting_joint``.
          rainbows: colour table, indexed with ``int(cls) % 20``.
          iframe: frame number, returned unchanged in the result.
          object_config: class ids to report and draw; other classes are
              skipped. The default is an immutable tuple so it cannot be
              shared-and-mutated across calls.

      Returns:
          tuple: ``([im0s[0], im0, det_xywh, iframe], strout)`` where ``im0``
          is the annotated image, ``det_xywh`` is a list of
          ``[cls, x, y, w, h, conf]`` rows with normalised xywh, and
          ``strout`` is a timing summary string.
      """
      time0 = time.time()
      _path, img, im0s, _vid_cap, pred, seg_pred = datas[0:6]
      pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)
      time1 = time.time()

      det = pred[0]  # one image is processed per call
      im0 = im0s[0].copy()
      gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
      det_xywh = []

      # Two segmentation outputs go to illBuildings; otherwise seg_pred is
      # handled as a water-only map.
      if len(seg_pred) == 2:
          im0, water = illBuildings(seg_pred, im0)
      else:
          im0, water = drawWater(seg_pred, im0)
      time2 = time.time()

      def water_ratio(box):
          """Fraction of the box area that is set in the water mask."""
          x1, y1, x2, y2 = box[:4]
          area = (x2 - x1) * (y2 - y1)
          if area <= 0:
              # Degenerate box: avoid 0/0 (previously produced NaN plus a
              # runtime warning); such boxes are dropped either way.
              return 0.0
          return np.sum(water[int(y1):int(y2), int(x1):int(x2)]) / area

      if len(det) > 0:
          # Rescale boxes from the network input size to the original image size.
          det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

          # Use the segmentation mask to keep only boxes lying on water.
          boxes = det.clone().cpu().numpy()
          area_factors = np.array([water_ratio(box) for box in boxes])
          det = det[area_factors > 0.1]

          # Draw the surviving detections.
          for *xyxy, conf, cls in reversed(det):
              cls_id = int(cls)
              if cls_id not in object_config:
                  # Not one of the requested target classes: do not show it.
                  continue
              xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
              det_xywh.append([float(cls), *xywh, float(conf)])  # label format
              im0 = draw_painting_joint(xyxy, im0, label_arraylist[cls_id], score=conf,
                                        color=rainbows[cls_id % 20], line_thickness=None)
      time3 = time.time()

      strout = 'nms:%s illBuilding:%s detDraw:%s ' % (
          get_ms(time0, time1), get_ms(time1, time2), get_ms(time2, time3))
      return [im0s[0], im0, det_xywh, iframe], strout
  def preprocess(par):
      """Read frames from a video source forever and feed them to a queue.

      Reads the video and produces, per frame, the original image and the
      image prepared for detection (numpy format, as yielded by
      ``LoadStreams``). When the source cannot be opened or reading fails,
      waits 10 seconds and tries again. This function never returns.

      Args:
          par: dict with keys ``'name'`` (process label used in logs),
              ``'source'`` (video source) and ``'queOut'`` (queue receiving
              ``[path, img, im0s, vid_cap, iframe]`` lists).
      """
      print('#####process:', par['name'])
      while True:
          source = par['source']

          # Probe the source, then release the probe immediately so that
          # reconnect attempts do not pile up open capture handles.
          cap = cv2.VideoCapture(source)
          opened = cap.isOpened()
          cap.release()

          if not opened:
              print('###read error:%s ' % (source))
              time.sleep(10)
              continue

          print('#### read %s success!' % (source))
          iframe = 0  # frame counter restarts on every (re)connect
          try:
              dataset = LoadStreams(source, img_size=640, stride=32)
              for path, img, im0s, vid_cap in dataset:
                  par['queOut'].put([path, img, im0s, vid_cap, iframe])
                  iframe += 1
          except Exception:
              # Best-effort reader: keep retrying, but show what went wrong
              # instead of discarding the exception silently.
              print('###read error:%s ' % (source))
              traceback.print_exc()
              time.sleep(10)
  def gpu_process(par):
      """Run the detection and segmentation models on queued frames.

      Loads both models once, then loops forever: takes a frame bundle from
      ``par['queIn']``, runs the detector and the segmentation model, writes
      a timing line to ``FP_DEBUG`` and puts
      ``[path, img, im0s, vid_cap, pred, seg_pred, iframe]`` on
      ``par['queOut']``. This function never returns.

      Args:
          par: dict with keys ``'name'``, ``'weights'``, ``'device'``,
              ``'seg_nclass'``, ``'seg_weights'``, ``'queIn'``, ``'queOut'``
              and ``'debug'``.
      """
      print('#####process:', par['name'])
      half = True  # run the detector in FP16

      # Detection model.
      device = par['device']
      print('###line127:', par['device'])
      model = attempt_load(par['weights'], map_location=par['device'])  # load FP32 model
      if half:
          model.half()

      # Segmentation model.
      # NOTE(review): 'seg_nclass' and 'seg_weights' are read from the config
      # but not used; the class counts, weights path and device below are
      # hard-coded. Confirm whether the config values should be honoured.
      seg_nclass = par['seg_nclass']
      seg_weights = par['seg_weights']
      nclass = [2, 2]
      seg_network = BiSeNet_MultiOutput(nclass)
      seg_weights_path = 'weights/segmentation/WaterBuilding.pth'
      segmodel = SegModel(model=seg_network, nclass=nclass, weights=seg_weights_path,
                          device='cuda:0', multiOutput=True)

      while True:
          if par['queIn'].empty():
              time.sleep(1 / 300)  # short nap to avoid a busy loop
              continue

          time0 = time.time()
          datas = par['queIn'].get()
          path, img, im0s, vid_cap, iframe = datas[0:5]
          time1 = time.time()

          img = torch.from_numpy(img).to(device)
          img = img.half() if half else img.float()  # uint8 to fp16/32
          img /= 255.0  # 0 - 255 to 0.0 - 1.0
          time2 = time.time()

          # Inference only: disable autograd so no graph is recorded.
          with torch.no_grad():
              pred = model(img, augment=False)[0]
              time3 = time.time()
              seg_pred = segmodel.eval(im0s[0], outsize=None, smooth_kernel=20)
          time4 = time.time()

          fpStr = 'process:%s ,iframe:%d,getdata:%s,copygpu:%s,dettime:%s,segtime:%s , time:%s, queLen:%d ' % (
              par['name'], iframe, get_ms(time0, time1), get_ms(time1, time2),
              get_ms(time2, time3), get_ms(time3, time4), get_ms(time0, time4),
              par['queIn'].qsize())
          FP_DEBUG.write(fpStr + '\n')

          par['queOut'].put([path, img, im0s, vid_cap, pred, seg_pred, iframe])
          if par['debug']:
              print('#####process:', par['name'], ' line107')
  def get_cls(array):
      """Return the most frequent value in ``array`` converted to ``int``.

      When several values share the highest count, the one that appeared
      first wins. An empty input raises ``ValueError``.
      """
      tally = Counter(array)
      # max() returns the first key with the highest count; Counter keeps
      # keys in first-seen order.
      most_frequent = max(tally, key=tally.get)
      return int(most_frequent)
  def save_problem_images(post_results, iimage_cnt, names, streamName='live-THSAHD5M',
                          outImaDir='problems/images_tmp', imageTxtFile=False):
      """Save the best "problem" frame of a batch as an AI/original image pair.

      The frame with the highest mean detection confidence is selected. Two
      JPEGs are written to ``outImaDir``: the AI-annotated image (``_AI.jpg``)
      and the original image (``_OR.jpg``).

      Args:
          post_results: non-empty list of ``[original_img, ai_img, dets, iframe]``
              entries; each ``dets`` is a non-empty list of
              ``[cls, x, y, w, h, conf]`` rows.
          iimage_cnt: counter embedded in the file names.
          names: class names indexed by class id.
          streamName: stream identifier embedded in the file names.
          outImaDir: output directory (created if missing).
          imageTxtFile: when true, also write a ``.txt`` file next to each
              image containing that image's path.

      Returns:
          dict with keys ``imgOR``/``imgORname`` (the AI-annotated image and
          its ``_AI.jpg`` file name -- the historical key naming is kept),
          ``imgAR``/``imgARname`` (the original image and its ``_OR.jpg``
          file name), ``uid``, ``time_str`` and ``type``.
      """
      # Each detection row is [cls, x, y, w, h, conf].
      dets_list = [result[2] for result in post_results]
      mean_scores = [np.array(dets)[:, 5].mean() for dets in dets_list]  # mean conf
      best_index = mean_scores.index(max(mean_scores))  # index of the problem frame in this batch

      best = post_results[best_index]
      img_bak = best[0]     # original image
      img_send = best[1]    # AI-annotated image
      best_frame = best[3]  # absolute frame number

      cls_max = get_cls(row[0] for row in dets_list[best_index])
      type_name = names[cls_max]
      stamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
      uid = ''.join(random.sample(string.ascii_letters + string.digits, 16))

      # Example name:
      # 2022-01-20-15-57-36_frame-368-720_type-<type>_qVh4zI08ZlwJN9on_s-live-THSAHD5M_OR.jpg
      stem = '%s/%s_frame-%d-%d_type-%s_%s_s-%s' % (
          outImaDir, stamp, best_frame, iimage_cnt, type_name, uid, streamName)
      outname_ai = stem + '_AI.jpg'
      outname_or = stem + '_OR.jpg'

      # cv2.imwrite does not create directories, so make sure the target exists.
      if outImaDir:
          os.makedirs(outImaDir, exist_ok=True)
      cv2.imwrite(outname_ai, img_send)
      cv2.imwrite(outname_or, img_bak)

      if imageTxtFile:
          for image_path in (outname_ai, outname_or):
              with open(image_path.replace('.jpg', '.txt'), 'w', encoding='utf-8') as fp:
                  fp.write(image_path + '\n')

      return {
          'imgOR': img_send,
          # Fix: this entry used to hold the AI image as well, which did not
          # match 'imgARname' (the '_OR.jpg' file of the original image).
          'imgAR': img_bak,
          'uid': uid,
          'imgORname': os.path.basename(outname_ai),
          'imgARname': os.path.basename(outname_or),
          'time_str': stamp,
          'type': type_name,
      }
  def post_process(par):
      """Worker loop: post-process model outputs and periodically save problem images.

      For every item taken from ``par['queIn']`` the frame is post-processed
      with ``post_process_`` and the result is put on ``par['queOut']``.
      Results that contain detections are collected; whenever the frame
      number is a multiple of ``par['fpsample']`` and something has been
      collected, ``save_problem_images`` is called and the collection is
      reset. A timing line is written to ``FP_DEBUG`` per frame. This
      function never returns.

      Args:
          par: dict with keys ``'name'``, ``'conf_thres'``, ``'iou_thres'``,
              ``'classes'``, ``'labelnames'``, ``'rainbows'``, ``'fpsample'``,
              ``'queIn'``, ``'queOut'`` and ``'debug'``.
      """
      print('#####process:', par['name'])

      # Post-processing parameters.
      conf_thres = par['conf_thres']
      iou_thres = par['iou_thres']
      _classes = par['classes']  # read from the config but not used below
      label_file = par['labelnames']
      rainbows = par['rainbows']
      fpsample = par['fpsample']

      names = get_labelnames(label_file)
      label_arraylist = get_label_arrays(names, rainbows, outfontsize=40)
      collected = []  # results with detections since the last save

      while True:
          if par['queIn'].empty():
              time.sleep(1 / 300)
              continue

          started = time.time()
          datas = par['queIn'].get()
          iframe = datas[6]
          if par['debug']:
              print('#####process:', par['name'], ' line129')

          p_result, timeOut = post_process_(
              datas, conf_thres, iou_thres, names, label_arraylist, rainbows, iframe)
          par['queOut'].put(p_result)

          # Every `fpsample` frames, save an image if problems were collected.
          if iframe % fpsample == 0 and collected:
              save_problem_images(collected, iframe, names)
              collected = []

          # Keep this frame for the next save when it has detections.
          if len(p_result[2]) > 0:
              collected.append(p_result)

          finished = time.time()
          outstr = 'process:%s ,iframe:%d,%s , time:%s, queLen:%d ' % (
              par['name'], iframe, timeOut, get_ms(started, finished), par['queIn'].qsize())
          FP_DEBUG.write(outstr + '\n')
  def save_logfile(name, txt):
      """Append a timestamped line to a log file, creating the file if needed.

      Args:
          name: path of the log file.
          txt: message to record.

      Each call adds one line of the form ``'YYYY-mm-dd HH:MM:SS <txt> '``.
      """
      stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
      # Append mode creates a missing file and always writes at the end.
      # The previous 'r+' mode wrote at offset 0 and overwrote earlier entries.
      with open(name, 'a') as fp:
          fp.write('%s %s \n' % (stamp, txt))
  def time_str():
      """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
      now = time.localtime()
      return time.strftime("%Y-%m-%d %H:%M:%S", now)
  if __name__ == "__main__":
      # Script entry point: start the pipeline described by the JSON config.
      # NOTE(review): work_stream is not defined or imported in the visible
      # part of this file -- confirm where it comes from.
      jsonfile = 'config/queRiver.json'
      work_stream(jsonfile)