diff --git a/AI.py b/AI.py index 7d39aea..25c4809 100644 --- a/AI.py +++ b/AI.py @@ -5,11 +5,13 @@ from segutils.trtUtils import segtrtEval,yolov5Trtforward,OcrTrtForward from segutils.trafficUtils import tracfficAccidentMixFunction from utils.torch_utils import select_device -from utilsK.queRiver import get_labelnames,get_label_arrays,post_process_,img_pad,draw_painting_joint,detectDraw,getDetections,getDetectionsFromPreds +from utilsK.queRiver import get_labelnames, img_pad, getDetections, getDetectionsFromPreds, scale_back from utilsK.jkmUtils import pre_process, post_process, get_return_data from trackUtils.sort import moving_average_wang from utils.datasets import letterbox +from utils.general import non_max_suppression, scale_coords,xyxy2xywh,overlap_box_suppression +from utils.plots import draw_painting_joint,get_label_arrays import numpy as np import torch import math @@ -18,6 +20,7 @@ import torch.nn.functional as F from copy import deepcopy from scipy import interpolate import glob +from loguru import logger def get_images_videos(impth, imageFixs=['.jpg','.JPG','.PNG','.png'],videoFixs=['.MP4','.mp4','.avi']): imgpaths=[];###获取文件里所有的图像 @@ -33,9 +36,9 @@ def get_images_videos(impth, imageFixs=['.jpg','.JPG','.PNG','.png'],videoFixs=[ if postfix in videoFixs: videopaths = [impth ] print('%s: test Images:%d , test videos:%d '%(impth, len(imgpaths), len(videopaths))) - return imgpaths,videopaths + return imgpaths,videopaths -def xywh2xyxy(box,iW=None,iH=None): +def xywh2xy(box,iW=None,iH=None): xc,yc,w,h = box[0:4] x0 =max(0, xc-w/2.0) x1 =min(1, xc+w/2.0) @@ -73,13 +76,15 @@ def score_filter_byClass(pdetections,score_para_2nd): ret.append(det) return ret # 按类过滤 -def filter_byClass(pdetections,allowedList): - ret=[] +def filter_byClass(pdetections, fiterList): + ret = [] for det in pdetections: - score,cls = det[4],det[5] - if int(cls) in allowedList: - ret.append(det) - elif str(int(cls)) in allowedList: + score, cls = det[4], det[5] + if int(cls) in fiterList: + continue + elif str(int(cls)) in fiterList: + continue + else: ret.append(det) return ret @@ -111,9 +116,90 @@ def plat_format(ocr): return label.upper() -def AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,objectPar={ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False,'score_byClass':{x:0.1 for x in range(30)} }, font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3} ,segPar={'modelSize':(640,360),'mean':(0.485, 0.456, 0.406),'std' :(0.229, 0.224, 0.225),'numpy':False, 'RGB_convert_first':True},mode='others',postPar=None): +def post_process_det(pred,padInfos,img,im0s,conf_thres,iou_thres,label_arraylist,rainbows,font,score_byClass,fiterList,ovlap_thres=None): + time0 = time.time() + pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False) + if ovlap_thres: + pred = overlap_box_suppression(pred, ovlap_thres) + time1 = time.time() + det = pred[0] ###一次检测一张图片 + det_xywh = []; + im0 = im0s.copy() + #im0 = im0s[0] + if len(det) > 0: + # Rescale boxes from img_size to im0 size + if not padInfos: + det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() + else: + # print('####line131:',det[:, :]) + det[:, :4] = scale_back(det[:, :4], padInfos).round() - #输入参数 + + for *xyxy, conf, cls in reversed(det): + cls_c = cls.cpu().numpy() + conf_c = conf.cpu().numpy() + tt = [int(x.cpu()) for x in xyxy] + if fiterList: + if int(cls) in fiterList: ###如果不是所需要的目标,则不显示 + continue + if score_byClass: + if int(cls) in score_byClass.keys(): + if conf < score_byClass[int(cls)]: + continue + line = [*tt, float(conf_c), float(cls_c)] # label format + det_xywh.append(line) + time2 = time.time() + strout='nms:%s ,detDraw:%s '%(get_ms(time0,time1), get_ms(time1,time2) ) + return [im0s[0],im0s[0], det_xywh, 10],strout + +def post_process_seg(im0s,segmodel,boxes,ksize): + time0 = time.time() + im0 = im0s[0].copy() + segmodel.set_image(im0s[0]) + # # 创建一个空白掩码用于保存所有火焰 + # combined_mask = np.zeros((im0.shape[0], im0.shape[1]), dtype=np.uint8) + # # 创建边缘可视化图像 + # edge_image = np.zeros_like(im0s[0]) + # 处理每个火焰检测框 + det_xywhP = [] + for box in boxes: + x_min, y_min, x_max, y_max = box[:4] + # 转换为SAM需要的格式 + input_box = np.array([x_min, y_min, x_max, y_max]) + + # 使用框提示进行分割 + masks, _, _ = segmodel.predict( + box=input_box, + multimask_output=False # 只返回最佳掩码 + ) + + # 获取分割掩码 + flame_mask = masks[0].astype(np.uint8) + # 使用形态学操作填充小孔洞 + filled_mask = cv2.morphologyEx(flame_mask, cv2.MORPH_CLOSE, ksize) + # 查找所有轮廓(包括内部小点) + contours, _ = cv2.findContours(filled_mask.astype(np.uint8), + cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + continue + largest_contour = max(contours, key=cv2.contourArea) + # 通过轮廓填充。 + #cv2.drawContours(im0, [largest_contour], -1, (0, 0, 255), 2) + box.append(largest_contour) + det_xywhP.append(box) + time1 = time.time() + strout = 'segDraw:%s ' % get_ms(time0, time1) + return [im0s[0], im0s[0], det_xywhP, 10], strout + +def AI_process(im0s, model, segmodel, names, label_arraylist, rainbows, + objectPar={'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45, + 'segRegionCnt': 1, 'trtFlag_det': False,'trtFlag_seg': False,'score_byClass':None,'fiterList':[]}, + font={'line_thickness': None, 'fontSize': None, 'boxLine_thickness': None, + 'waterLineColor': (0, 255, 255), 'waterLineWidth': 3}, + segPar={'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), + 'numpy': False, 'RGB_convert_first': True}, mode='others', postPar=None): + # 输入参数 # im0s---原始图像列表 # model---检测模型,segmodel---分割模型(如若没有用到,则为None) # @@ -126,16 +212,13 @@ def AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,objectPar={ 'h # #strout---统计AI处理个环节的时间 # Letterbox - half,device,conf_thres,iou_thres,allowedList = objectPar['half'],objectPar['device'],objectPar['conf_thres'],objectPar['iou_thres'],objectPar['allowedList'] + half, device, conf_thres, iou_thres, fiterList,score_byClass = objectPar['half'], objectPar['device'], objectPar['conf_thres'], \ + objectPar['iou_thres'], objectPar['fiterList'], objectPar['score_byClass'] - trtFlag_det,trtFlag_seg,segRegionCnt = objectPar['trtFlag_det'],objectPar['trtFlag_seg'],objectPar['segRegionCnt'] - if 'ovlap_thres_crossCategory' in objectPar.keys(): ovlap_thres = objectPar['ovlap_thres_crossCategory'] - else: ovlap_thres = None + trtFlag_det, trtFlag_seg, segRegionCnt = objectPar['trtFlag_det'], objectPar['trtFlag_seg'], objectPar[ + 'segRegionCnt'] - if 'score_byClass' in objectPar.keys(): score_byClass = objectPar['score_byClass'] - else: score_byClass = None - - time0=time.time() + time0 = time.time() if trtFlag_det: img, padInfos = img_pad(im0s[0], size=(640,640,3)) ;img = [img] else: @@ -151,10 +234,10 @@ def AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,objectPar={ 'h img = torch.from_numpy(img).to(device) img = img.half() if half else img.float() # uint8 to fp16/32 img /= 255.0 - time01=time.time() + time01 = time.time() - if segmodel: - seg_pred,segstr = segmodel.eval(im0s[0] ) + if segmodel: + seg_pred,segstr = segmodel.eval(im0s[0]) segFlag=True else: seg_pred = None;segFlag=False;segstr='Not implemented' @@ -170,12 +253,11 @@ def AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,objectPar={ 'h time2=time.time() - p_result, timeOut = getDetectionsFromPreds(pred,img,im0s[0],conf_thres=conf_thres,iou_thres=iou_thres,ovlap_thres=ovlap_thres,padInfos=padInfos) - if score_byClass: - p_result[2] = score_filter_byClass(p_result[2],score_byClass) - #if mode=='highWay3.0': - #if segmodel: - if segPar and segPar['mixFunction']['function']: + p_result, timeOut = getDetectionsFromPreds(pred, img, im0s[0], conf_thres=conf_thres, iou_thres=iou_thres, + ovlap_thres=None, padInfos=padInfos) + # if mode=='highWay3.0': + # if segmodel: + if segPar and segPar['mixFunction']['function']: mixFunction = segPar['mixFunction']['function'];H,W = im0s[0].shape[0:2] parMix = segPar['mixFunction']['pars'];#print('###line117:',parMix,p_result[2]) @@ -186,19 +268,26 @@ def AI_process(im0s,model,segmodel,names,label_arraylist,rainbows,objectPar={ 'h p_result.append(seg_pred) else: - timeMixPost=':0 ms' - #print('#### line121: segstr:%s timeMixPost:%s timeOut:%s'%( segstr.strip(), timeMixPost,timeOut )) - time_info = 'letterbox:%.1f, seg:%.1f , infer:%.1f,%s, seginfo:%s ,timeMixPost:%s '%( (time01-time0)*1000, (time1-time01)*1000 ,(time2-time1)*1000,timeOut , segstr.strip(),timeMixPost ) - if allowedList: - p_result[2] = filter_byClass(p_result[2],allowedList) + timeMixPost = ':0 ms' + # print('#### line121: segstr:%s timeMixPost:%s timeOut:%s'%( segstr.strip(), timeMixPost,timeOut )) + time_info = 'letterbox:%.1f, seg:%.1f , infer:%.1f,%s, seginfo:%s ,timeMixPost:%s ' % ( + (time01 - time0) * 1000, (time1 - time01) * 1000, (time2 - time1) * 1000, timeOut, segstr.strip(), timeMixPost) + if fiterList: + p_result[2] = filter_byClass(p_result[2], fiterList) - print('-'*10,p_result[2]) - return p_result,time_info -def default_mix(predlist,par): - return predlist[0],'' -def AI_process_N(im0s,modelList,postProcess): + if score_byClass: + p_result[2] = score_filter_byClass(p_result[2], score_byClass) - #输入参数 + print('-' * 10, p_result[2]) + return p_result, time_info + + +def default_mix(predlist, par): + return predlist[0], '' + + +def AI_process_N(im0s, modelList, postProcess,score_byClass=None,fiterList=[]): + # 输入参数 ## im0s---原始图像列表 ## modelList--所有的模型 # postProcess--字典{},包括后处理函数,及其参数 @@ -206,22 +295,28 @@ def AI_process_N(im0s,modelList,postProcess): ##ret[0]--检测结果; ##ret[1]--时间信息 - #modelList包括模型,每个模型是一个类,里面的eval函数可以输出该模型的推理结果 - modelRets=[ model.eval(im0s[0]) for model in modelList] + # modelList包括模型,每个模型是一个类,里面的eval函数可以输出该模型的推理结果 + modelRets = [model.eval(im0s[0]) for model in modelList] - timeInfos = [ x[1] for x in modelRets] - timeInfos=''.join(timeInfos) - timeInfos=timeInfos + timeInfos = [x[1] for x in modelRets] + timeInfos = ''.join(timeInfos) + timeInfos = timeInfos - #postProcess['function']--后处理函数,输入的就是所有模型输出结果 - mixFunction =postProcess['function'] - predsList = [ modelRet[0] for modelRet in modelRets ] - H,W = im0s[0].shape[0:2] - postProcess['pars']['imgSize'] = (W,H) + # postProcess['function']--后处理函数,输入的就是所有模型输出结果 + mixFunction = postProcess['function'] + predsList = [modelRet[0] for modelRet in modelRets] + H, W = im0s[0].shape[0:2] + postProcess['pars']['imgSize'] = (W, H) - #ret就是混合处理后的结果 - ret = mixFunction( predsList, postProcess['pars']) - return ret[0],timeInfos+ret[1] + # ret就是混合处理后的结果 + ret = mixFunction(predsList, postProcess['pars']) + det = ret[0] + if fiterList: + det = filter_byClass(det, fiterList) + if score_byClass: + det = score_filter_byClass(det, score_byClass) + + return det, timeInfos + ret[1] def getMaxScoreWords(detRets0): maxScore=-1;maxId=0 for i,detRet in enumerate(detRets0): @@ -230,8 +325,8 @@ def getMaxScoreWords(detRets0): maxScore = detRet[4] return maxId -def AI_process_C(im0s,modelList,postProcess): - #函数定制的原因: +def AI_process_C(im0s, modelList, postProcess,score_byClass,fiterList): + # 函数定制的原因: ## 之前模型处理流是 ## 图片---> 模型1-->result1;图片---> 模型2->result2;[result1,result2]--->后处理函数 ## 本函数的处理流程是 @@ -287,21 +382,30 @@ def AI_process_C(im0s,modelList,postProcess): res_real = detRets1[0][0] res_real="".join( list(filter(lambda x:(ord(x) >19968 and ord(x)<63865 ) or (ord(x) >47 and ord(x)<58 ),res_real))) - #detRets1[0][0]="".join( list(filter(lambda x:(ord(x) >19968 and ord(x)<63865 ) or (ord(x) >47 and ord(x)<58 ),detRets1[0][0]))) - _detRets0_obj[maxId].append(res_real ) - _detRets0_obj = [_detRets0_obj[maxId]]##只输出有OCR的那个船名结果 - ocrInfo=detRets1[0][1] - print( ' _detRets0_obj:{} _detRets0_others:{} '.format( _detRets0_obj, _detRets0_others ) ) - rets=_detRets0_obj+_detRets0_others - t3=time.time() - outInfos='total:%.1f ,where det:%.1f, ocr:%s'%( (t3-t0)*1000, (t1-t0)*1000, ocrInfo) + # detRets1[0][0]="".join( list(filter(lambda x:(ord(x) >19968 and ord(x)<63865 ) or (ord(x) >47 and ord(x)<58 ),detRets1[0][0]))) + _detRets0_obj[maxId].append(res_real) + _detRets0_obj = [_detRets0_obj[maxId]] ##只输出有OCR的那个船名结果 + ocrInfo = detRets1[0][1] + print(' _detRets0_obj:{} _detRets0_others:{} '.format(_detRets0_obj, _detRets0_others)) + rets = _detRets0_obj + _detRets0_others + + if fiterList: + rets = filter_byClass(rets, fiterList) + if score_byClass: + rets = score_filter_byClass(rets, score_byClass) + t3 = time.time() + outInfos = 'total:%.1f ,where det:%.1f, ocr:%s' % ((t3 - t0) * 1000, (t1 - t0) * 1000, ocrInfo) #print('###line233:',detRets1,detRets0 ) return rets,outInfos -def AI_process_forest(im0s,model,segmodel,names,label_arraylist,rainbows,half=True,device=' cuda:0',conf_thres=0.25, iou_thres=0.45,allowedList=[0,1,2,3], font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3} ,trtFlag_det=False,SecNms=None): - #输入参数 +def AI_process_forest(im0s, model, segmodel, names, label_arraylist, rainbows, half=True, device=' cuda:0', + conf_thres=0.25, iou_thres=0.45, + font={'line_thickness': None, 'fontSize': None, 'boxLine_thickness': None, + 'waterLineColor': (0, 255, 255), 'waterLineWidth': 3}, trtFlag_det=False, + SecNms=None,ksize=None,score_byClass=None,fiterList=[]): + # 输入参数 # im0s---原始图像列表 # model---检测模型,segmodel---分割模型(如若没有用到,则为None) #输出:两个元素(列表,字符)构成的元组,[im0s[0],im0,det_xywh,iframe],strout @@ -313,7 +417,7 @@ def AI_process_forest(im0s,model,segmodel,names,label_arraylist,rainbows,half=Tr # #strout---统计AI处理个环节的时间 # Letterbox - time0=time.time() + time0 = time.time() if trtFlag_det: img, padInfos = img_pad(im0s[0], size=(640,640,3)) ;img = [img] else: @@ -329,26 +433,22 @@ def AI_process_forest(im0s,model,segmodel,names,label_arraylist,rainbows,half=Tr img = img.half() if half else img.float() # uint8 to fp16/32 img /= 255.0 # 0 - 255 to 0.0 - 1.0 - if segmodel: - seg_pred,segstr = segmodel.eval(im0s[0] ) - segFlag=True - else: - seg_pred = None;segFlag=False - time1=time.time() - pred = yolov5Trtforward(model,img) if trtFlag_det else model(img,augment=False)[0] + + time1 = time.time() + pred = yolov5Trtforward(model, img) if trtFlag_det else model(img, augment=False)[0] + + p_result, timeOut = post_process_det(pred,padInfos,img,im0s,conf_thres,iou_thres,label_arraylist,rainbows,font,score_byClass,fiterList) + if segmodel and len(p_result[2])>0: + segmodel.set_image(im0s[0]) + p_result, timeOut = post_process_seg(im0s,segmodel,p_result[2],ksize) + + time2 = time.time() + time_info = 'letterbox:%.1f, infer:%.1f, ' % ((time1 - time0) * 1000, (time2 - time1) * 1000) + return p_result, time_info + timeOut - time2=time.time() - datas = [[''], img, im0s, None,pred,seg_pred,10] - - ObjectPar={ 'object_config':allowedList, 'slopeIndex':[] ,'segmodel':segFlag,'segRegionCnt':0 } - p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,ObjectPar=ObjectPar,font=font,padInfos=padInfos,ovlap_thres=SecNms) - #print('###line274:',p_result[2]) - #p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,object_config=allowedList,segmodel=segFlag,font=font,padInfos=padInfos) - time_info = 'letterbox:%.1f, infer:%.1f, '%( (time1-time0)*1000,(time2-time1)*1000 ) - return p_result,time_info+timeOut -def AI_det_track( im0s_in,modelPar,processPar,sort_tracker,segPar=None): - im0s,iframe=im0s_in[0],im0s_in[1] +def AI_det_track(im0s_in, modelPar, processPar, sort_tracker, segPar=None): + im0s, iframe = im0s_in[0], im0s_in[1] model = modelPar['det_Model'] segmodel = modelPar['seg_Model'] half,device,conf_thres, iou_thres,trtFlag_det = processPar['half'], processPar['device'], processPar['conf_thres'], processPar['iou_thres'],processPar['trtFlag_det'] @@ -705,17 +805,22 @@ def ocr_process(pars): info_str= ('pre-process:%.2f TRTforward:%.2f (%s) postProcess:%2.f decoder:%.2f, Total:%.2f , pred:%s'%(get_ms(time2,time1 ),get_ms(time3,time2 ),trtstr, get_ms(time4,time3 ), get_ms(time5,time4 ), get_ms(time5,time1 ), preds_str ) ) return preds_str,info_str -def AI_process_Ocr(im0s,modelList,device,detpar): +def AI_process_Ocr(im0s, modelList, device, detpar): timeMixPost = ':0 ms' new_device = torch.device(device) time0 = time.time() + img, padInfos = pre_process(im0s[0], new_device) ocrModel = modelList[1] time1 = time.time() - preds,timeOut = modelList[0].eval(img) + if not detpar['trtFlag_det']: + preds, timeOut = modelList[0].eval(img) + boxes = post_process(preds, padInfos, device, conf_thres=detpar['conf_thres'], iou_thres=detpar['iou_thres'], + nc=detpar['nc']) # 后处理 + else: + boxes, timeOut = modelList[0].eval(im0s[0]) time2 = time.time() - boxes = post_process(preds, padInfos, device, conf_thres=detpar['conf_thres'], iou_thres=detpar['iou_thres'], - nc=detpar['nc']) # 后处理 + imagePatches = [im0s[0][int(x[1]):int(x[3]), int(x[0]):int(x[2])] for x in boxes] detRets1 = [ocrModel.eval(patch) for patch in imagePatches] @@ -728,9 +833,9 @@ def AI_process_Ocr(im0s,modelList,device,detpar): dets.append([label, xyxy]) time_info = 'pre_process:%.1f, det:%.1f , ocr:%.1f ,timeMixPost:%s ' % ( - (time1 - time0) * 1000, (time2 - time1) * 1000, (time3 - time2) * 1000, timeMixPost) + (time1 - time0) * 1000, (time2 - time1) * 1000, (time3 - time2) * 1000, timeMixPost) - return [im0s[0],im0s[0],dets,0],time_info + return [im0s[0], im0s[0], dets, 0], time_info def AI_process_Crowd(im0s,model,device,postPar):