测试框架
|
|
@ -0,0 +1,285 @@
|
|||
from loguru import logger
|
||||
import cv2,os,time, json, glob
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import tensorrt as trt
|
||||
from models.experimental import attempt_load
|
||||
from DrGraph.util.masterUtils import get_needed_objectsIndex
|
||||
from DrGraph.util.stdc import stdcModel
|
||||
|
||||
from DrGraph.appIOs.conf.ModelTypeEnum import *
|
||||
from DrGraph.appIOs.conf.ModelUtils import *
|
||||
from DrGraph.util import aiHelper
|
||||
from DrGraph.util.drHelper import *
|
||||
|
||||
# Immutable work item for one analysis pass over a single frame: raw images,
# detection/segmentation models, label metadata, drawing parameters, and
# per-image bookkeeping (image_name).
AnalysisFrameType = namedtuple('AnalysisFrameData', ['images', 'model', 'seg_model', 'names', 'label_arrays',
                                                     'rainbows', 'object_params', 'font', 'image_name', 'seg_params', 'mode', 'post_params' ])
|
||||
|
||||
class BussinessBase:
|
||||
@staticmethod
|
||||
def createModel(opt):
|
||||
business = opt['business']
|
||||
if business == 'illParking':
|
||||
from .Bussiness_IllParking import Bussiness_IllParking
|
||||
return Bussiness_IllParking(opt)
|
||||
|
||||
    def __init__(self, opt):
        """Build the common test-harness configuration for a business model.

        Populates self.param with model/label/IO paths derived from the
        business name, lets subclasses adjust it via extraConfig(), logs the
        effective configuration, and validates that every referenced file or
        directory exists (creating the output directory on demand).
        """
        self.bussiness = opt['business']
        from DrGraph.appIOs.conf.ModelUtils import MODEL_CONFIG
        # NOTE(review): the model code is hard-wired to '019' — presumably the
        # illegal-parking entry in MODEL_CONFIG; confirm before adding businesses.
        self.code = '019'
        model_method = MODEL_CONFIG[self.code]
        self.modelClass = model_method[0]        # model-wrapper constructor
        self.modelProcessFun = model_method[3]   # per-frame processing function

        self.param = {
            'device':'0', ### GPU index; TRT engines only support 0 (single GPU)
            # 'labelnames':"../AIlib2/DrGraph/weights/conf/%s/labelnames.json" % (self.bussiness), ### detection label table (old path)
            'labelnames':"../weights/conf/%s/labelnames.json" % (self.bussiness), ### detection class-name lookup table
            'max_workers':1, ### number of parallel worker threads
            'Detweights':"../weights/%s/yolov5_%s_fp16.engine"%(self.bussiness ,opt['gpu'] ),### detection model path
            'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,5,6,7,8,9] ],### which detection classes are shown/output
            'seg_nclass':4,### number of segmentation classes (default is 2)
            'segRegionCnt':2,### number of contour regions to keep from segmentation
            'Segweights' : "../weights/%s/stdc_360X640_3090_fp16.engine" % (self.bussiness), ### segmentation weights path
            'postFile': '../weights/conf/%s/para.json'%(self.bussiness),### post-processing parameter file
            'txtFontSize':20,### label text size
            'digitFont': { 'line_thickness':2,'boxLine_thickness':1, 'fontSize':1.0,'waterLineColor':(0,255,255),'segLineShow':True,'waterLineWidth':2},### box/line drawing settings
            'testImgPath':'./DrGraph/appIOs/samples/%s/' % (self.bussiness),### test image directory
            'testOutPath':'./DrGraph/appIOs/results/%s/' % (self.bussiness),### output directory for annotated results
            'segPar': {'mixFunction':{'function':self.modelProcessFun, 'pars':{}}}
        }
        # Subclass hook: adjust/extend self.param before validation below.
        self.extraConfig(opt)

        logger.warning(f"""[{self.bussiness}] 业务配置 - {[key for key in self.param if self.param[key] is not None]} - 重点配置:
   检测类别(labelnames):{self.param['labelnames']} >>>>>> {ioHelper.get_labelnames(self.param['labelnames'])}
   检测模型路径(Detweights): {self.param['Detweights']}
   分割模型权重文件(Segweights): {self.param['Segweights']}
   后处理参数文件(postFile): {self.param['postFile']}
   测试图像路径(testImgPath): {self.param['testImgPath']}
   输出图像位置(testOutPath): {self.param['testOutPath']}
   输出图像路径: {self.param['testOutPath']}""")
        # Fail fast if any configured path is missing.
        ioHelper.checkFile(self.param['labelnames'], '检测类别')
        ioHelper.checkFile(self.param['Detweights'], '检测模型路径')
        ioHelper.checkFile(self.param['postFile'], '后处理参数文件')
        ioHelper.checkFile(self.param['Segweights'], '分割模型权重文件')
        ioHelper.checkFile(self.param['testImgPath'], '测试图像路径')
        if ioHelper.checkFile(self.param['testOutPath'], '输出图像路径') is False:
            # The output directory is created on demand, then re-checked.
            os.makedirs(self.param['testOutPath'], exist_ok=True)
            ioHelper.checkFile(self.param['testOutPath'], '创建后再检查输出图像路径')
|
||||
|
||||
    def extraConfig(self, opt):
        """Subclass hook called from __init__ to customize self.param; no-op by default."""
        pass
|
||||
|
||||
def setParam(self, key, value):
|
||||
self.param[key] = value
|
||||
|
||||
def addParams(self, params):
|
||||
for key, value in params.items():
|
||||
self.param[key] = value
|
||||
|
||||
def getTestParam_Model(self):
|
||||
device = torchHelper.select_device(self.param['device']) # 1 device
|
||||
half = device.type != 'cpu'
|
||||
trtFlag_det=self.param['trtFlag_det']
|
||||
if trtFlag_det:
|
||||
Detweights = self.param['Detweights'] ##升级后的检测模型
|
||||
trt_logger = trt.Logger(trt.Logger.ERROR)
|
||||
with open(Detweights, "rb") as f, trt.Runtime(trt_logger) as runtime:
|
||||
model = runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
|
||||
logger.info(f"step {step}: 情况 1 - 成功载入 det model trt [{Detweights}]"); step += 1
|
||||
else:
|
||||
Detweights = self.param['Detweights']
|
||||
model = attempt_load(Detweights, map_location=device) # load FP32 model
|
||||
logger.info(f'step {step}: 情况 2 - 成功载入 det model pth [{Detweights}]'); step += 1
|
||||
if half:
|
||||
model.half() # 启用半精度推理
|
||||
return model
|
||||
|
||||
def getTestParam_SegModel(self):
|
||||
segmodel= None
|
||||
if self.param['Segweights']:
|
||||
if self.bussiness == 'cityMangement2':
|
||||
from DMPR import DMPRModel
|
||||
segmodel = DMPRModel(weights=self.param['Segweights'], par = self.param['segPar'])
|
||||
else:
|
||||
segmodel = stdcModel(weights=self.param['Segweights'], par = self.param['segPar'])
|
||||
else:
|
||||
logger.warning('############None seg model is loaded###########:' )
|
||||
return segmodel
|
||||
|
||||
    def getTestParm_ObjectPar(self):
        """Assemble the objectPar dict consumed by aiHelper.AI_process.

        Derives device/precision, allowed detection classes, TRT flags, and the
        thresholds read from the post-processing JSON file. Side effect: writes
        the derived 'trtFlag_det'/'trtFlag_seg' back into self.param.
        """
        from DrGraph.util import torchHelper
        device = torchHelper.select_device(self.param['device']) # 1 device
        half = device.type != 'cpu' # 2 half
        postFile= self.param['postFile']
        # 3 allowedList
        allowedList,allowedList_string=get_needed_objectsIndex(self.param['detModelpara'])
        # 4 segRegionCnt
        segRegionCnt=self.param['segRegionCnt']

        # TRT is assumed whenever the weight file is a serialized .engine.
        if self.param['Segweights']:
            self.param['trtFlag_seg']=True if self.param['Segweights'].endswith('.engine') else False
        else:
            self.param['trtFlag_seg']=False
        self.param['trtFlag_det']=True if self.param['Detweights'].endswith('.engine') else False

        trtFlag_det=self.param['trtFlag_det'] # 5 trtFlag_det
        trtFlag_seg=self.param['trtFlag_seg'] # 6 trtFlag_seg

        detPostPar = ioHelper.get_postProcess_para_dic(postFile)
        # 7 conf_thres 8 iou_thres
        conf_thres,iou_thres,classes,rainbows = detPostPar["conf_thres"],detPostPar["iou_thres"],detPostPar["classes"],detPostPar["rainbows"]
        # 9 ovlap_thres_crossCategory (optional cross-category NMS threshold)
        if 'ovlap_thres_crossCategory' in detPostPar.keys():
            ovlap_thres_crossCategory=detPostPar['ovlap_thres_crossCategory']
        else:
            ovlap_thres_crossCategory = None
        # 10 score_byClass (optional per-class score thresholds)
        if 'score_byClass' in detPostPar.keys(): score_byClass=detPostPar['score_byClass']
        else: score_byClass = None

        objectPar={
            'half':half,
            'device':device,
            'conf_thres':conf_thres,
            'ovlap_thres_crossCategory':ovlap_thres_crossCategory,
            'iou_thres':iou_thres,
            'allowedList':allowedList,
            'segRegionCnt':segRegionCnt,
            'trtFlag_det':trtFlag_det,
            'trtFlag_seg':trtFlag_seg ,
            'score_byClass':score_byClass}
        return objectPar
|
||||
    def run(self):
        """Test-harness entry point.

        Loads the business model via MODEL_CONFIG, backfills any missing model
        parameters with test-harness defaults, reads every sample image from
        testImgPath, runs doAnalysis (single- or multi-threaded), and logs
        overall timing.
        """
        postFile= self.param['postFile']
        digitFont= self.param['digitFont']
        detPostPar = ioHelper.get_postProcess_para_dic(postFile)
        rainbows = detPostPar["rainbows"]

        mode_paras=self.param['detModelpara']
        allowedList,allowedList_string=get_needed_objectsIndex(mode_paras)
        # Fixed harness identifiers; not meaningful outside this test run.
        requestId = '1234'
        gpu_name = '3090'
        base_dir = None
        env = None
        from GPUtil import getAvailable, getGPUs
        # Pick GPUs under 80% load and 80% memory; the first one is used.
        gpu_ids = getAvailable(maxLoad=0.80, maxMemory=0.80)
        modelObject = self.modelClass(gpu_ids[0], allowedList, requestId, gpu_name, base_dir, env)
        model_conf = modelObject.model_conf
        model_param = model_conf[1]
        # Backfill any parameter the model wrapper did not provide with the
        # harness's own test configuration (logged as errors for visibility).
        if 'model' not in model_param:
            model_param['model'] = self.getTestParam_Model()
            logger.error(f"[{self.bussiness}] 业务配置 - 缺少模型参数 model - 置为测试配置量")
        if 'segmodel' not in model_param:
            model_param['segmodel'] = self.getTestParam_SegModel()
            logger.error(f"[{self.bussiness}] 业务配置 - 缺少模型参数 segmodel - {model_param}")
        if 'objectPar' not in model_param:
            model_param['objectPar'] = self.getTestParm_ObjectPar()
            logger.error(f"[{self.bussiness}] 业务配置 - 缺少模型参数 objectPar - {model_param}")
        if 'segPar' not in model_param:
            self.param['segPar']['seg_nclass'] = self.param['seg_nclass']
            model_param['segPar']=self.param['segPar']
            logger.error(f"[{self.bussiness}] 业务配置 - 缺少模型参数 segPar - {model_param}")
        if 'mode' not in model_param:
            model_param['mode'] = self.param['mode'] if 'mode' in self.param.keys() else 'others'
            logger.error(f"[{self.bussiness}] 业务配置 - 缺少模型参数 mode - 置为测试配置量{model_param['mode']}")
        if 'postPar' not in model_param:
            model_param['postPar'] = self.param['postPar'] if 'postPar' in self.param.keys() else None
            logger.error(f"[{self.bussiness}] 业务配置 - 缺少模型参数 postPar - 置为测试配置量{model_param['postPar']}")

        labelnames = self.param['labelnames']
        names = ioHelper.get_labelnames(labelnames)
        # Pre-render one label image per class name for drawing.
        label_arraylist = imgHelper.get_label_arrays(names,rainbows,outfontsize=self.param['txtFontSize'],fontpath="./DrGraph/appIOs/conf/platech.ttf")

        max_workers=self.param['max_workers']

        # Collect test images (and videos) from the sample directory.
        impth = self.param['testImgPath']
        outpth = self.param['testOutPath']
        imgpaths=[]### every image in the directory
        for postfix in ['.jpg','.JPG','.PNG','.png']:
            imgpaths.extend(glob.glob('%s/*%s'%(impth,postfix )) )
        videopaths=[]### every video in the directory (collected but not processed below)
        for postfix in ['.MP4','.mp4','.avi']:
            videopaths.extend(glob.glob('%s/*%s'%(impth,postfix )) )

        # Build one AnalysisFrameType work item per image.
        frames=[]
        for imgpath in imgpaths:
            im0s=[cv2.imread(imgpath)]
            analysisFrameData = AnalysisFrameType(
                im0s,
                model_param['model'], # model
                model_param['segmodel'], # segmodel
                names,
                label_arraylist,
                rainbows,
                model_param['objectPar'], # objectPar
                digitFont,
                os.path.basename(imgpath),
                model_param['segPar'], # segPar
                model_param['mode'], # mode
                model_param['postPar'] # postPar
            )
            frames.append(analysisFrameData)
        logger.info(f'共读入 %d 张图片待处理' % len(imgpaths));
        t1=time.time()
        # Run analysis either sequentially or on a thread pool.
        if max_workers==1:
            for index, img in enumerate(frames):
                logger.warning(f'-'*20 + ' 处理图片 ' + imgpaths[index] + '-'*20);
                t5=time.time()
                self.doAnalysis(img)
                t6=time.time()
        else:
            with ThreadPoolExecutor(max_workers=max_workers) as t:
                for result in t.map(self.doAnalysis, frames):
                    t=result

        t2=time.time()
        if len(imgpaths)>0:
            logger.info('%d 张图片共耗时:%.1f ms ,依次为:%.1f ms, 占用 %d 线程'%(len(imgpaths),(t2-t1)*1000, (t2-t1)*1000.0/len(imgpaths) , max_workers) );
|
||||
|
||||
    def doAnalysis(self, frameData: AnalysisFrameType):
        """Analyse one frame: run AI_process, draw the detections, and write
        the annotated image, an optional segmentation mask, and a box .txt
        file into testOutPath.

        frameData fields (by position): images, model, seg_model, names,
        label_arrays, rainbows, object_params, font, image_name, seg_params,
        mode, post_params.

        Returns the literal string 'success' (used as a thread-pool sentinel).
        """
        time00 = time.time()
        H,W,C = frameData[0][0].shape

        with TimeDebugger('业务分析') as td:
            # Core inference: detection plus optional segmentation/post-processing.
            p_result, timeOut = aiHelper.AI_process(frameData.images, frameData.model, frameData.seg_model,
                                                    frameData.names, frameData.label_arrays, frameData.rainbows,
                                                    objectPar=frameData.object_params, font=frameData.font,
                                                    segPar=frameData.seg_params, mode=frameData.mode, postPar=frameData.post_params)
            td.addStep('AI_Process')
            # p_result[1]: annotated image, p_result[2]: detection boxes.
            p_result[1] = drawHelper.drawAllBox(p_result[2],p_result[1],frameData[4],frameData[5],frameData[7])
            td.addStep('drawAllBox')
            image_array = p_result[1]

            # Save the annotated image under the source image's basename.
            cv2.imwrite(os.path.join(self.param['testOutPath'], frameData[8] ) ,image_array)
            bname = frameData[8].split('.')[0]
            if frameData[2]:
                # A segmentation model ran: persist the mask as PNG when present.
                if len(p_result)==5:
                    image_mask = p_result[4]
                    if isinstance(image_mask,np.ndarray) and image_mask.shape[0]>0:
                        cv2.imwrite(os.path.join(self.param['testOutPath'],bname+'_mask.png' ) , (image_mask).astype(np.uint8))
            td.addStep('testOutPath')
            # Dump the raw boxes, one comma-separated row per detection.
            boxes=p_result[2]
            with open(os.path.join(self.param['testOutPath'], bname+'.txt' ),'w' ) as fp:
                for box in boxes:
                    box_str=[str(x) for x in box]
                    out_str=','.join(box_str)+'\n'
                    fp.write(out_str)
            td.addStep('fp')
        logger.info(td.getReportInfo())
        return 'success'
|
||||
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
from loguru import logger
|
||||
import cv2, time
|
||||
import numpy as np
|
||||
|
||||
from DrGraph.util.drHelper import *
|
||||
from .Bussiness import BussinessBase
|
||||
|
||||
class Bussiness_IllParking(BussinessBase):
|
||||
    def __init__(self, opt):
        """Illegal-parking business: reuses the BussinessBase harness unchanged."""
        logger.info("create AlAlg_IllParking")
        super().__init__(opt)
|
||||
|
||||
@staticmethod
|
||||
def postProcess(pred, cvMask, pars):
|
||||
#pred:直接预测结果,不要原图。预测结果[0,1,2,...],不是[车、T角点,L角点]
|
||||
#mask_cv:分割结果图,numpy格式(H,W),结果是int,[0,1,2,...]
|
||||
#pars: 其它参数,dict格式
|
||||
'''三个标签:车、T角点,L角点'''
|
||||
'''输入:落水人员的结果(类别+坐标)、原图
|
||||
|
||||
过程:将车辆识别框外扩,并按contours形成区域。
|
||||
T角点与L角点的坐标合并为列表。
|
||||
判断每个车辆contours区域内有几个角点,少于2个则判断违停。
|
||||
返回:最终违停车辆标记结果图、违停车辆信息(坐标、类别、置信度)。
|
||||
'''
|
||||
#输入的是[cls,x0,y0,x1,y1,score]---> [x0,y0,x1,y1,cls,score]
|
||||
#输出的也是[cls,x0,y0,x1,y1,score]
|
||||
#pred = [ [ int(x[4]) ,*x[1:5], x[5] ] for x in pred]
|
||||
|
||||
#pred = [[ *x[1:5],x[0], x[5] ] for x in pred]
|
||||
pred = [[ *x[0:4],x[5], x[4] ] for x in pred]
|
||||
|
||||
##统一格式
|
||||
imgSize=pars['imgSize']
|
||||
'''1、pred中车辆识别框形成列表,T角点与L角点形成列表'''
|
||||
tW1=time.time()
|
||||
init_vehicle=[]
|
||||
init_corner = []
|
||||
|
||||
for i in range(len(pred)):
|
||||
#if pred[i][4]=='TCorner' or pred[i][4]=='LCorner': #vehicle、TCorner、LCorner
|
||||
if pred[i][4]==1 or pred[i][4]==2: #vehicle、TCorner、LCorner
|
||||
init_corner.append(pred[i])
|
||||
else:
|
||||
init_vehicle.append(pred[i])
|
||||
|
||||
'''2、init_corner中心点坐标计算,并形成列表。'''
|
||||
tW2 = time.time()
|
||||
center_corner=[]
|
||||
for i in range(len(init_corner)):
|
||||
center_corner.append(mathHelper.center_coordinate(init_corner[i]))
|
||||
|
||||
|
||||
'''3、遍历每个车辆识别框,扩充矩形区域,将矩形区域形成contours,判断扩充区域内的。'''
|
||||
tW3 = time.time()
|
||||
final_weiting=[] #违停车辆列表
|
||||
'''遍历车辆列表,扩大矩形框形成contours'''
|
||||
for i in range(len(init_vehicle)):
|
||||
boundbxs1=[init_vehicle[i][0],init_vehicle[i][1],init_vehicle[i][2],init_vehicle[i][3]]
|
||||
width_boundingbox=init_vehicle[i][2]-init_vehicle[i][0] #框宽度
|
||||
height_boundingbox=init_vehicle[i][2] - init_vehicle[i][0] #框长度
|
||||
#当框长大于宽,则是水平方向车辆;否则认为是竖向车辆
|
||||
if width_boundingbox>=height_boundingbox:
|
||||
ex_width=0.4*(init_vehicle[i][2]-init_vehicle[i][0]) #矩形扩充宽度,取车宽0.4倍 #膨胀系数小一些。角点设成1个。
|
||||
ex_height=0.2*(init_vehicle[i][2]-init_vehicle[i][0]) #矩形扩充宽度,取车长0.2倍
|
||||
boundbxs1 = imgHelper.expand_rectangle(boundbxs1, imgSize, ex_width, ex_height) # 扩充后矩形对角坐标
|
||||
else:
|
||||
ex_width=0.2*(init_vehicle[i][2]-init_vehicle[i][0]) #竖向,不需要改变变量名称,将系数对换下就行。(坐标点顺序还是1234不变)
|
||||
ex_height=0.4*(init_vehicle[i][2]-init_vehicle[i][0]) #
|
||||
boundbxs1 = imgHelper.expand_rectangle(boundbxs1, imgSize, ex_width, ex_height) # 扩充后矩形对角坐标
|
||||
contour_temp = mathHelper.fourcorner_coordinate(boundbxs1) #得到扩充后矩形框的contour
|
||||
contour_temp_=np.array(contour_temp)#contour转为array
|
||||
contour_temp_=np.float32(contour_temp_)
|
||||
|
||||
'''遍历角点识别框中心坐标是否在contours内,在则计1'''
|
||||
zzz=0
|
||||
for j in range(len(center_corner)):
|
||||
flag = cv2.pointPolygonTest(contour_temp_, (center_corner[j][0], center_corner[j][1]), False) #若为False,会找点是否在内,外,或轮廓上(相应返回+1, -1, 0)。
|
||||
if flag==+1:
|
||||
zzz+=1
|
||||
'''contours框内小于等于1个角点,认为不在停车位内'''
|
||||
# if zzz<=1:
|
||||
if zzz<1:
|
||||
final_weiting.append(init_vehicle[i])
|
||||
#print('t7-t6',t7-t6)
|
||||
#print('final_weiting',final_weiting)
|
||||
|
||||
'''4、绘制保存检违停车辆图像'''
|
||||
|
||||
tW4=time.time()
|
||||
'''
|
||||
colors = Colors()
|
||||
if final_weiting is not None:
|
||||
for i in range(len(final_weiting)):
|
||||
lbl='illegal park'
|
||||
xyxy=[final_weiting[i][0],final_weiting[i][1],final_weiting[i][2],final_weiting[i][3]]
|
||||
c = int(5)
|
||||
plot_one_box(xyxy, _img_cv, label=lbl, color=colors(c, True), line_thickness=3)
|
||||
final_img=_img_cv
|
||||
'''
|
||||
tW5=time.time()
|
||||
# cv2.imwrite('final_result.png', _img_cv)
|
||||
|
||||
|
||||
timeStr = ' step1:%s step2:%s step3:%s save:%s'%(\
|
||||
timeHelper.deltaTimeString_MS(tW2,tW1), \
|
||||
timeHelper.deltaTimeString_MS(tW3,tW2), \
|
||||
timeHelper.deltaTimeString_MS(tW4,tW3), \
|
||||
timeHelper.deltaTimeString_MS(tW5,tW4) )
|
||||
|
||||
#final_weiting-----[x0,y0,x1,y1,cls,score]
|
||||
#输出的也是outRe----[cls,x0,y0,x1,y1,score]
|
||||
|
||||
#outRes = [ [ 3 ,*x[0:4], x[5] ] for x in final_weiting]###违停用3表示
|
||||
|
||||
outRes = [ [ *x[0:4], x[5],3 ] for x in final_weiting]###违停用3表示
|
||||
|
||||
return outRes,timeStr #返回最终绘制的结果图、违停车辆(坐标、类别、置信度)
|
||||
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
from loguru import logger
|
||||
import time
|
||||
import tensorrt as trt
|
||||
from DMPR import DMPRModel
|
||||
from traceback import format_exc
|
||||
from models.experimental import attempt_load
|
||||
|
||||
from DrGraph.util.drHelper import *
|
||||
from DrGraph.util.Constant import *
|
||||
from DrGraph.enums.ExceptionEnum import ExceptionType
|
||||
from DrGraph.util.stdc import stdcModel
|
||||
|
||||
# 河道模型、河道检测模型、交通模型、人员落水模型、城市违章公共模型
|
||||
# Shared wrapper for the river / river-detection / traffic / person-overboard /
# city-violation models.
class Model1:
    __slots__ = "model_conf"
    # Targets the 3090 deployment.
    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load detection (+ optional segmentation) weights for *modeType* and
        publish self.model_conf = (modeType, model_param, allowedList, names,
        rainbows).

        Raises ServiceException(MODEL_LOADING_EXCEPTION) on any failure.
        """
        try:
            start = time.time()
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            logger.info('__init__(device={}, allowedList={}, requestId={}, modeType={}, gpu_name={}, base_dir={}, env={})', \
                        device, allowedList, requestId, modeType, gpu_name, base_dir, env)
            # modeType.value[4] is the per-model parameter factory.
            par = modeType.value[4](str(device), gpu_name)
            mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            new_device = torchHelper.select_device(par.get('device'))
            half = new_device.type != 'cpu'
            Detweights = par['Detweights']
            if par['trtFlag_det']:
                # Serialized TensorRT engine.
                with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                    model = runtime.deserialize_cuda_engine(f.read())
            else:
                model = attempt_load(Detweights, map_location=new_device)  # load FP32 model
                if half: model.half()
            par['segPar']['seg_nclass'] = par['seg_nclass']
            Segweights = par['Segweights']
            if Segweights:
                # cityMangement3 uses the DMPR parking detector; all others STDC.
                if modeType.value[3] == 'cityMangement3':
                    segmodel = DMPRModel(weights=Segweights, par=par['segPar'])
                else:
                    segmodel = stdcModel(weights=Segweights, par=par['segPar'])
            else:
                segmodel = None
            objectPar = {
                'half': half,
                'device': new_device,
                'conf_thres': postFile["conf_thres"],
                'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
                'iou_thres': postFile["iou_thres"],
                # filtering for the highway models
                'segRegionCnt': par['segRegionCnt'],
                'trtFlag_det': par['trtFlag_det'],
                'trtFlag_seg': par['trtFlag_seg'],
                'score_byClass':par['score_byClass'] if 'score_byClass' in par.keys() else None,
                'fiterList': par['fiterList'] if 'fiterList' in par.keys() else []
            }
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": objectPar,
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId)
|
||||
|
||||
|
|
@ -0,0 +1,785 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
from pickle import dumps, loads
|
||||
from traceback import format_exc
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
from loguru import logger
|
||||
|
||||
from DrGraph.util.drHelper import *
|
||||
from DrGraph.util import aiHelper
|
||||
from DrGraph.util.Constant import *
|
||||
from DrGraph.enums.ExceptionEnum import ExceptionType
|
||||
|
||||
from .ModelTypeEnum import ModelType
|
||||
from DrGraph.util.PlotsUtils import get_label_arrays
|
||||
|
||||
sys.path.extend(['..', '../AIlib2'])
|
||||
FONT_PATH = "./DrGraph/appIOs/conf/platech.ttf"
|
||||
|
||||
from DrGraph.Bussiness.Models import *
|
||||
|
||||
# Registry mapping a model code to its wiring:
#   [0] constructor lambda -> model wrapper instance
#   [1] the ModelType enum member itself
#   [2] label-array builder
#   [3] per-frame processing function
MODEL_CONFIG = {
    # illegal-parking model
    ModelType.ILLPARKING_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model1(x, y, r, ModelType.ILLPARKING_MODEL, t, z, h),
        ModelType.ILLPARKING_MODEL,
        lambda x, y, z: one_label(x, y, z), # MODEL_CONFIG[code][2]
        lambda x: model_process(x)
    ),
}
|
||||
|
||||
# 河道模型、河道检测模型、交通模型、人员落水模型、城市违章公共模型
|
||||
# Shared wrapper for the river / river-detection / traffic / person-overboard /
# city-violation models.
# NOTE(review): this class is a byte-for-byte duplicate of Model1 elsewhere in
# the project — consider consolidating.
class OneModel:
    __slots__ = "model_conf"

    # Targets the 3090 deployment.
    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load detection (+ optional segmentation) weights for *modeType* and
        publish self.model_conf = (modeType, model_param, allowedList, names,
        rainbows).

        Raises ServiceException(MODEL_LOADING_EXCEPTION) on any failure.
        """
        try:
            start = time.time()
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            logger.info('__init__(device={}, allowedList={}, requestId={}, modeType={}, gpu_name={}, base_dir={}, env={})', \
                        device, allowedList, requestId, modeType, gpu_name, base_dir, env)
            # modeType.value[4] is the per-model parameter factory.
            par = modeType.value[4](str(device), gpu_name)
            mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            new_device = torchHelper.select_device(par.get('device'))
            half = new_device.type != 'cpu'
            Detweights = par['Detweights']
            if par['trtFlag_det']:
                # Serialized TensorRT engine.
                with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                    model = runtime.deserialize_cuda_engine(f.read())
            else:
                model = attempt_load(Detweights, map_location=new_device)  # load FP32 model
                if half: model.half()
            par['segPar']['seg_nclass'] = par['seg_nclass']
            Segweights = par['Segweights']
            if Segweights:
                # cityMangement3 uses the DMPR parking detector; all others STDC.
                if modeType.value[3] == 'cityMangement3':
                    segmodel = DMPRModel(weights=Segweights, par=par['segPar'])
                else:
                    segmodel = stdcModel(weights=Segweights, par=par['segPar'])
            else:
                segmodel = None
            objectPar = {
                'half': half,
                'device': new_device,
                'conf_thres': postFile["conf_thres"],
                'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
                'iou_thres': postFile["iou_thres"],
                # filtering for the highway models
                'segRegionCnt': par['segRegionCnt'],
                'trtFlag_det': par['trtFlag_det'],
                'trtFlag_seg': par['trtFlag_seg'],
                'score_byClass':par['score_byClass'] if 'score_byClass' in par.keys() else None,
                'fiterList': par['fiterList'] if 'fiterList' in par.keys() else []
            }
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": objectPar,
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId)
|
||||
|
||||
# 纯分类模型
|
||||
# Classification-only model wrapper (city management).
class cityManagementModel:
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Instantiate every sub-model listed in the parameter factory's
        'models' entry and publish self.model_conf = (modeType, model_param,
        allowedList, names, rainbows).

        Raises ServiceException(MODEL_LOADING_EXCEPTION) on any failure.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            postProcess = par['postProcess']
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            # Build each configured sub-model from its (class, weight, par) spec.
            modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
            model_param = {
                "modelList": modelList,
                "postProcess": postProcess,
                "score_byClass":par['score_byClass'] if 'score_byClass' in par.keys() else None,
                "fiterList":par['fiterList'] if 'fiterList' in par.keys() else [],
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
|
||||
def detSeg_demo2(args):
    """Run the multi-model detect+segment pipeline (AI_process_N) on one frame.

    args: (model_conf, frame, request_id). Returns [[None, None, result]] so
    the output matches the shape expected by the unified analysis interface.
    ServiceException propagates unchanged; anything else becomes
    ServiceException(MODEL_ANALYSE_EXCEPTION).
    """
    model_conf, frame, request_id = args
    params = model_conf[1]
    try:
        analysed = AI_process_N([frame], params['modelList'], params['postProcess'],
                                params['score_byClass'], params['fiterList'])[0]
        return [[None, None, analysed]]  # adapter for the unified return shape
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
def model_process(args):
    """Run the standard detect(+segment) pipeline on a single frame.

    args: (model_conf, frame, request_id) where model_conf is the tuple built
    by the model-wrapper classes above.

    Returns whatever aiHelper.AI_process returns. ServiceException propagates
    unchanged; any other failure is wrapped in
    ServiceException(MODEL_ANALYSE_EXCEPTION).
    """
    model_conf, frame, request_id = args
    model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
    try:
        # segPar is deep-copied via a pickle round-trip so per-call mutation
        # inside AI_process cannot leak into the shared configuration.
        return aiHelper.AI_process([frame], model_param['model'], model_param['segmodel'], names, model_param['label_arraylist'],
                                   rainbows, objectPar=model_param['objectPar'], font=model_param['digitFont'],
                                   segPar=loads(dumps(model_param['segPar'])), mode=model_param['mode'],
                                   postPar=model_param['postPar'])
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
|
||||
# 森林模型、车辆模型、行人模型、烟火模型、 钓鱼模型、航道模型、乡村模型、城管模型公共模型
|
||||
# Shared wrapper for the forest / vehicle / pedestrian / smoke-fire / fishing /
# waterway / rural / city-management models.
class TwoModel:
    __slots__ = "model_conf"

    def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Load the TensorRT detection engine (plus a SAM segmenter for the
        city fire-area model) and publish self.model_conf.

        Raises ServiceException(MODEL_LOADING_EXCEPTION) on any failure.
        """
        s = time.time()
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device1), gpu_name)
            device = select_device(par.get('device'))
            names = par['labelnames']
            half = device.type != 'cpu'
            Detweights = par['Detweights']
            # Detection is always a serialized TensorRT engine here.
            with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                model = runtime.deserialize_cuda_engine(f.read())
            if modeType == ModelType.CITY_FIREAREA_MODEL:
                # The fire-area model segments with Segment Anything (SAM).
                sam = sam_model_registry[par['sam_type']](checkpoint=par['Samweights'])
                sam.to(device=device)
                segmodel = SamPredictor(sam)
            else:
                segmodel = None

            postFile = par['postFile']
            conf_thres = postFile["conf_thres"]
            iou_thres = postFile["iou_thres"]
            rainbows = postFile["rainbows"]
            otc = postFile.get("ovlap_thres_crossCategory")
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "half": half,
                "device": device,
                "conf_thres": conf_thres,
                "iou_thres": iou_thres,
                "trtFlag_det": par['trtFlag_det'],
                "otc": otc,
                "ksize":par['ksize'] if 'ksize' in par.keys() else None,
                "score_byClass": par['score_byClass'] if 'score_byClass' in par.keys() else None,
                "fiterList": par['fiterList'] if 'fiterList' in par.keys() else []
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
|
||||
def forest_process(args):
    """Run the forest-style (detection-only, optional SAM) pipeline on one frame.

    args: (model_conf, frame, request_id). Returns the AI_process_forest
    result. ServiceException propagates unchanged; any other failure becomes
    ServiceException(MODEL_ANALYSE_EXCEPTION).
    """
    model_conf, frame, request_id = args
    model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
    try:
        return AI_process_forest([frame], model_param['model'], model_param['segmodel'], names,
                                 model_param['label_arraylist'], rainbows, model_param['half'], model_param['device'],
                                 model_param['conf_thres'], model_param['iou_thres'],font=model_param['digitFont'],
                                 trtFlag_det=model_param['trtFlag_det'], SecNms=model_param['otc'],ksize = model_param['ksize'],
                                 score_byClass=model_param['score_byClass'],fiterList=model_param['fiterList'])
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
# Multi-sub-model wrapper: builds a list of models plus a shared post-process.
class MultiModel:
    __slots__ = "model_conf"

    def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Instantiate every sub-model from the parameter factory's 'models'
        list and publish self.model_conf = (modeType, model_param, allowedList,
        names, rainbows).

        Raises ServiceException(MODEL_LOADING_EXCEPTION) on any failure.
        """
        s = time.time()
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device1), gpu_name)
            postProcess = par['postProcess']
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            # Each entry supplies its own constructor, weight path and params.
            modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
            model_param = {
                "modelList": modelList,
                "postProcess": postProcess,
                "score_byClass": par['score_byClass'] if 'score_byClass' in par.keys() else None,
                "fiterList": par['fiterList'] if 'fiterList' in par.keys() else []
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
|
||||
def channel2_process(args):
    """Run the multi-model channel pipeline on a single frame.

    *args* is (model_conf, frame, request_id) where model_conf is the tuple
    built by MultiModel. Returns [[None, None, result]] so the shape matches
    the unified downstream interface.

    Raises:
        ServiceException: re-raised as-is, or raised on any other analysis error.
    """
    model_conf, frame, request_id = args
    param = model_conf[1]
    try:
        detections = AI_process_C([frame], param['modelList'], param['postProcess'],
                                  param['score_byClass'], param['fiterList'])[0]
        # Wrap in the [[None, None, result]] shape the shared interface expects.
        return [[None, None, detections]]
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
def get_label_arraylist(*args):
    """Build font settings and pre-rendered label images for box drawing.

    *args* is (width, height, names, rainbows). Thickness and font sizes scale
    linearly with the frame width relative to a 1920-px baseline.
    Returns (digitFont, label_arraylist, font_config) with font_config being
    (line, text_width, text_height, fontScale, tf).
    """
    width, height, names, rainbows = args
    line = max(1, int(round(width / 1920 * 3)))
    tf = max(line - 1, 1)
    fontScale = line * 0.33
    # Measure a representative score label to size the rendered label arrays.
    text_width, text_height = cv2.getTextSize(' 0.95', 0, fontScale=fontScale, thickness=tf)[0]
    numFontSize = float(format(width / 1920 * 1.1, '.1f'))
    digitFont = {
        'line_thickness': line,
        'boxLine_thickness': line,
        'fontSize': numFontSize,
        'waterLineColor': (0, 255, 255),
        'segLineShow': False,
        'waterLineWidth': line,
        'wordSize': text_height,
        'label_location': 'leftTop',
    }
    label_arraylist = get_label_arrays(names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
    return digitFont, label_arraylist, (line, text_width, text_height, fontScale, tf)
# 船只模型
|
||||
class ShipModel:
    # Single tuple of model configuration; __slots__ avoids a per-instance __dict__.
    __slots__ = "model_conf"

    def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Load the oriented-bounding-box (OBB) ship detector and its decoder.

        Populates ``self.model_conf`` as
        (modeType, model_param, allowedList, names, rainbows).

        Raises:
            ServiceException: when model loading fails for any reason.
        """
        s = time.time()
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device1), gpu_name)
            model, decoder2 = load_model_decoder_OBB(par)
            # The decoder is passed back to inference through the par dict.
            par['decoder'] = decoder2
            names = par['labelnames']
            rainbows = par['postFile']["rainbows"]
            model_param = {
                "model": model,
                "par": par
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception:
            # Use logger.error, not logger.exception: format_exc() already carries
            # the traceback (logger.exception would log it twice), and this keeps
            # the style consistent with the other model-loader classes.
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
def obb_process(args):
    """Run OBB ship inference on a single frame.

    *args* is (model_conf, frame, request_id) where model_conf is the tuple
    built by ShipModel. Returns the raw OBB_infer result.

    Raises:
        ServiceException: re-raised as-is, or raised on any other analysis error.
    """
    model_conf, frame, request_id = args
    param = model_conf[1]
    try:
        return OBB_infer(param["model"], frame, param["par"])
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
# 车牌分割模型、健康码、行程码分割模型
|
||||
class IMModel:
    # Single tuple holding everything needed at inference time;
    # __slots__ avoids a per-instance __dict__.
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Load a TorchScript segmentation model for health/travel codes or car plates.

        Selects the 'plate' weights when modeType is ModelType.PLATE_MODEL,
        otherwise the 'code' (health/travel code) weights, and stores
        (modeType, allowedList, device, model, par, img_type) in self.model_conf.

        Raises:
            ServiceException: when model loading fails for any reason.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            # Default to the health/travel-code model; switch to the plate
            # weights only for the plate model type.
            img_type = 'code'
            if ModelType.PLATE_MODEL == modeType:
                img_type = 'plate'
            # Static per-type configuration: weight paths, class counts ('nc'),
            # detection thresholds, target CUDA device, and plate dilation ratios.
            par = {
                'code': {'weights': '../weights/pth/AIlib2/jkm/health_yolov5s_v3.jit', 'img_type': 'code', 'nc': 10},
                'plate': {'weights': '../weights/pth/AIlib2/jkm/plate_yolov5s_v3.jit', 'img_type': 'plate', 'nc': 1},
                'conf_thres': 0.4,
                'iou_thres': 0.45,
                'device': 'cuda:%s' % device,
                'plate_dilate': (0.5, 0.3)
            }

            new_device = torch.device(par['device'])
            # NOTE(review): torch.jit.load is called without map_location=new_device,
            # so the model is restored onto whatever device it was saved from —
            # confirm this matches par['device'] before use.
            model = torch.jit.load(par[img_type]['weights'])
            logger.info("########################加载 jit 模型成功 成功 ########################, requestId:{}",
                        requestId)
            self.model_conf = (modeType, allowedList, new_device, model, par, img_type)
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
def im_process(args):
    """Run the TorchScript code/plate segmentation pipeline on one frame.

    *args* is (frame, device, model, par, img_type, requestId) matching the
    contents of IMModel.model_conf. Returns the structured result of
    get_return_data.

    Raises:
        ServiceException: re-raised as-is, or raised on any other analysis error.
    """
    frame, device, model, par, img_type, requestId = args
    try:
        # Letterbox/normalize the frame into a model-ready tensor.
        img, padInfos = pre_process(frame, device)
        pred = model(img)
        # Post-process: NMS with per-type thresholds and class count.
        boxes = post_process(pred, padInfos, device, conf_thres=par['conf_thres'],
                             iou_thres=par['iou_thres'], nc=par[img_type]['nc'])
        dataBack = get_return_data(frame, boxes, modelType=img_type, plate_dilate=par['plate_dilate'])
        # (Removed leftover debug print of dataBack that polluted stdout.)
        return dataBack
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
def immulti_process(args):
    """Run the OCR multi-model pipeline on a single frame.

    *args* is (model_conf, frame, requestId); model_conf[1:4] hold
    (device, modelList, detpar) as stored by CARPLATEModel.

    Raises:
        ServiceException: re-raised as-is, or raised on any other analysis error.
    """
    model_conf, frame, requestId = args
    device = model_conf[1]
    modelList = model_conf[2]
    detpar = model_conf[3]
    try:
        return AI_process_Ocr([frame], modelList, device, detpar)
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
class CARPLATEModel:
    # Single tuple of model configuration; __slots__ avoids a per-instance __dict__.
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Load the in-house car-plate detection model list.

        Populates ``self.model_conf`` as
        (modeType, device, modelList, detpar, rainbows).

        Raises:
            ServiceException: when model loading fails for any reason.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            loaded = [cfg['model'](weights=cfg['weight'], par=cfg['par']) for cfg in par['models']]
            # Detection parameters come from the first configured sub-model.
            first_par = par['models'][0]['par']
            self.model_conf = (modeType, device, loaded, first_par, par['rainbows'])
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
class DENSECROWDCOUNTModel:
    # Single tuple of model configuration; __slots__ avoids a per-instance __dict__.
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load the dense-crowd-count model list and its post-processing params.

        Populates ``self.model_conf`` as
        (modeType, device, models, postPar, rainbows).

        Raises:
            ServiceException: when model loading fails for any reason.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            loaded = [cfg['model'](weights=cfg['weight'], par=cfg['par']) for cfg in par['models']]
            # One post-processing parameter dict per configured sub-model.
            post_params = [cfg['par'] for cfg in par['models']]
            self.model_conf = (modeType, device, loaded, post_params, par["rainbows"])
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
def cc_process(args):
    """Run dense-crowd counting on a single frame.

    *args* is (model_conf, frame, requestId); model_conf[1:4] hold
    (device, models, postPar) as stored by DENSECROWDCOUNTModel.

    Raises:
        ServiceException: re-raised as-is, or raised on any other analysis error.
    """
    model_conf, frame, requestId = args
    device = model_conf[1]
    models = model_conf[2]
    postPar = model_conf[3]
    try:
        return AI_process_Crowd([frame], models, device, postPar)
    except ServiceException as s:
        raise s
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
# # 百度AI图片识别模型
|
||||
# class BaiduAiImageModel:
|
||||
# __slots__ = "model_conf"
|
||||
|
||||
# def __init__(self, device=None, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
|
||||
# env=None):
|
||||
# try:
|
||||
# logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
|
||||
# requestId)
|
||||
# # 人体检测与属性识别、 人流量统计客户端
|
||||
# aipBodyAnalysisClient = AipBodyAnalysisClient(base_dir, env)
|
||||
# # 车辆检测检测客户端
|
||||
# aipImageClassifyClient = AipImageClassifyClient(base_dir, env)
|
||||
# rainbows = COLOR
|
||||
# vehicle_names = [VehicleEnum.CAR.value[1], VehicleEnum.TRICYCLE.value[1], VehicleEnum.MOTORBIKE.value[1],
|
||||
# VehicleEnum.CARPLATE.value[1], VehicleEnum.TRUCK.value[1], VehicleEnum.BUS.value[1]]
|
||||
# person_names = ['人']
|
||||
# self.model_conf = (modeType, aipImageClassifyClient, aipBodyAnalysisClient, allowedList, rainbows,
|
||||
# vehicle_names, person_names, requestId)
|
||||
# except Exception:
|
||||
# logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
|
||||
# raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
|
||||
# ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
|
||||
|
||||
|
||||
# def get_baidu_label_arraylist(*args):
|
||||
# width, height, vehicle_names, person_names, rainbows = args
|
||||
# # line = int(round(0.002 * (height + width) / 2) + 1)
|
||||
# line = max(1, int(round(width / 1920 * 3) + 1))
|
||||
# label = ' 0.97'
|
||||
# tf = max(line, 1)
|
||||
# fontScale = line * 0.33
|
||||
# text_width, text_height = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=tf)[0]
|
||||
# vehicle_label_arrays = get_label_arrays(vehicle_names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
|
||||
# person_label_arrays = get_label_arrays(person_names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
|
||||
# font_config = (line, text_width, text_height, fontScale, tf)
|
||||
# return vehicle_label_arrays, person_label_arrays, font_config
|
||||
|
||||
|
||||
# def baidu_process(args):
|
||||
# target, url, aipImageClassifyClient, aipBodyAnalysisClient, request_id = args
|
||||
# try:
|
||||
# # [target, url, aipImageClassifyClient, aipBodyAnalysisClient, requestId]
|
||||
# baiduEnum = BAIDU_MODEL_TARGET_CONFIG.get(target)
|
||||
# if baiduEnum is None:
|
||||
# raise ServiceException(ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[0],
|
||||
# ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[1]
|
||||
# + " target: " + target)
|
||||
# return baiduEnum.value[2](aipImageClassifyClient, aipBodyAnalysisClient, url, request_id)
|
||||
# except ServiceException as s:
|
||||
# raise s
|
||||
# except Exception:
|
||||
# logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
|
||||
# raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||
# ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
|
||||
def one_label(width, height, model_conf):
    """Attach per-resolution font/label drawing assets to a model's params.

    Mutates model_conf[1] in place, setting the 'digitFont',
    'label_arraylist' and 'font_config' entries; returns None.
    """
    names, rainbows = model_conf[3], model_conf[4]
    params = model_conf[1]
    digitFont, label_arraylist, font_config = get_label_arraylist(width, height, names, rainbows)
    params['digitFont'] = digitFont
    params['label_arraylist'] = label_arraylist
    params['font_config'] = font_config
# def dynamics_label(width, height, model_conf):
|
||||
# # modeType, model_param, allowedList, names, rainbows = model_conf
|
||||
# names = model_conf[3]
|
||||
# rainbows = model_conf[4]
|
||||
# model_param = model_conf[1]
|
||||
# digitFont, label_arraylist, font_config = get_label_arraylist(width, height, names, rainbows)
|
||||
# line = max(1, int(round(width / 1920 * 3)))
|
||||
# label = ' 0.95'
|
||||
# tf = max(line - 1, 1)
|
||||
# fontScale = line * 0.33
|
||||
# _, text_height = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=tf)[0]
|
||||
# label_dict = get_label_array_dict(rainbows, fontSize=text_height, fontPath=FONT_PATH)
|
||||
# model_param['digitFont'] = digitFont
|
||||
# model_param['label_arraylist'] = label_arraylist
|
||||
# model_param['font_config'] = font_config
|
||||
# model_param['label_dict'] = label_dict
|
||||
# def baidu_label(width, height, model_conf):
|
||||
# # modeType, aipImageClassifyClient, aipBodyAnalysisClient, allowedList, rainbows,
|
||||
# # vehicle_names, person_names, requestId
|
||||
# vehicle_names = model_conf[5]
|
||||
# person_names = model_conf[6]
|
||||
# rainbows = model_conf[4]
|
||||
# vehicle_label_arrays, person_label_arrays, font_config = get_baidu_label_arraylist(width, height, vehicle_names,
|
||||
# person_names, rainbows)
|
||||
# return vehicle_label_arrays, person_label_arrays, font_config
|
||||
|
||||
|
||||
# MODEL_CONFIG = {
|
||||
# # 车辆违停模型
|
||||
# ModelType.ILLPARKING_MODEL.value[1]: (
|
||||
# lambda x, y, r, t, z, h: Model1(x, y, r, ModelType.ILLPARKING_MODEL, t, z, h),
|
||||
# ModelType.ILLPARKING_MODEL,
|
||||
# lambda x, y, z: one_label(x, y, z), # MODEL_CONFIG[code][2]
|
||||
# lambda x: model_process(x)
|
||||
# ),
|
||||
# # # 加载河道模型
|
||||
# # ModelType.WATER_SURFACE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.WATER_SURFACE_MODEL, t, z, h),
|
||||
# # ModelType.WATER_SURFACE_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 加载森林模型
|
||||
# # # ModelType.FOREST_FARM_MODEL.value[1]: (
|
||||
# # # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.FOREST_FARM_MODEL, t, z, h),
|
||||
# # # ModelType.FOREST_FARM_MODEL,
|
||||
# # # lambda x, y, z: one_label(x, y, z),
|
||||
# # # lambda x: forest_process(x)
|
||||
# # # ),
|
||||
# # ModelType.FOREST_FARM_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FOREST_FARM_MODEL, t, z, h),
|
||||
# # ModelType.FOREST_FARM_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
|
||||
# # # 加载交通模型
|
||||
# # ModelType.TRAFFIC_FARM_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z, h),
|
||||
# # ModelType.TRAFFIC_FARM_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 加载防疫模型
|
||||
# # ModelType.EPIDEMIC_PREVENTION_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: IMModel(x, y, r, ModelType.EPIDEMIC_PREVENTION_MODEL, t, z, h),
|
||||
# # ModelType.EPIDEMIC_PREVENTION_MODEL,
|
||||
# # None,
|
||||
# # lambda x: im_process(x)),
|
||||
# # # 加载车牌模型
|
||||
# # ModelType.PLATE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: IMModel(x, y, r, ModelType.PLATE_MODEL, t, z, h),
|
||||
# # ModelType.PLATE_MODEL,
|
||||
# # None,
|
||||
# # lambda x: im_process(x)),
|
||||
# # # 加载车辆模型
|
||||
# # ModelType.VEHICLE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.VEHICLE_MODEL, t, z, h),
|
||||
# # ModelType.VEHICLE_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)
|
||||
# # ),
|
||||
# # # 加载行人模型
|
||||
# # ModelType.PEDESTRIAN_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.PEDESTRIAN_MODEL, t, z, h),
|
||||
# # ModelType.PEDESTRIAN_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)),
|
||||
# # # 加载烟火模型
|
||||
# # ModelType.SMOGFIRE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.SMOGFIRE_MODEL, t, z, h),
|
||||
# # ModelType.SMOGFIRE_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)),
|
||||
# # # 加载钓鱼游泳模型
|
||||
# # ModelType.ANGLERSWIMMER_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.ANGLERSWIMMER_MODEL, t, z, h),
|
||||
# # ModelType.ANGLERSWIMMER_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)),
|
||||
# # # 加载乡村模型
|
||||
# # ModelType.COUNTRYROAD_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.COUNTRYROAD_MODEL, t, z, h),
|
||||
# # ModelType.COUNTRYROAD_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)),
|
||||
# # # 加载船只模型
|
||||
# # ModelType.SHIP_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: ShipModel(x, y, r, ModelType.SHIP_MODEL, t, z, h),
|
||||
# # ModelType.SHIP_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: obb_process(x)),
|
||||
# # # 百度AI图片识别模型
|
||||
# # ModelType.BAIDU_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: BaiduAiImageModel(x, y, r, ModelType.BAIDU_MODEL, t, z, h),
|
||||
# # ModelType.BAIDU_MODEL,
|
||||
# # lambda x, y, z: baidu_label(x, y, z),
|
||||
# # lambda x: baidu_process(x)),
|
||||
# # # 航道模型
|
||||
# # ModelType.CHANNEL_EMERGENCY_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.CHANNEL_EMERGENCY_MODEL, t, z, h),
|
||||
# # ModelType.CHANNEL_EMERGENCY_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)),
|
||||
# # # 河道检测模型
|
||||
# # ModelType.RIVER2_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.RIVER2_MODEL, t, z, h),
|
||||
# # ModelType.RIVER2_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)),
|
||||
# # # 城管模型
|
||||
# # ModelType.CITY_MANGEMENT_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_MANGEMENT_MODEL, t, z, h),
|
||||
# # ModelType.CITY_MANGEMENT_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# # # 人员落水模型
|
||||
# # ModelType.DROWING_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.DROWING_MODEL, t, z, h),
|
||||
# # ModelType.DROWING_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 城市违章模型
|
||||
# # ModelType.NOPARKING_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.NOPARKING_MODEL, t, z, h),
|
||||
# # ModelType.NOPARKING_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 城市公路模型
|
||||
# # ModelType.CITYROAD_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.CITYROAD_MODEL, t, z, h),
|
||||
# # ModelType.CITYROAD_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)),
|
||||
# # # 加载坑槽模型
|
||||
# # ModelType.POTHOLE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.POTHOLE_MODEL, t, z, h),
|
||||
# # ModelType.POTHOLE_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)
|
||||
# # ),
|
||||
# # # 加载船只综合检测模型
|
||||
# # ModelType.CHANNEL2_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: MultiModel(x, y, r, ModelType.CHANNEL2_MODEL, t, z, h),
|
||||
# # ModelType.CHANNEL2_MODEL,
|
||||
# # lambda x, y, z: dynamics_label(x, y, z),
|
||||
# # lambda x: channel2_process(x)
|
||||
# # ),
|
||||
# # # 河道检测模型
|
||||
# # ModelType.RIVERT_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.RIVERT_MODEL, t, z, h),
|
||||
# # ModelType.RIVERT_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)),
|
||||
# # # 加载森林人群模型
|
||||
# # ModelType.FORESTCROWD_FARM_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FORESTCROWD_FARM_MODEL, t, z, h),
|
||||
# # ModelType.FORESTCROWD_FARM_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# # # 加载交通模型
|
||||
# # ModelType.TRAFFICFORDSJ_FARM_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFICFORDSJ_FARM_MODEL, t, z, h),
|
||||
# # ModelType.TRAFFICFORDSJ_FARM_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 加载智慧工地模型
|
||||
# # ModelType.SMARTSITE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.SMARTSITE_MODEL, t, z, h),
|
||||
# # ModelType.SMARTSITE_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
|
||||
# # # 加载垃圾模型
|
||||
# # ModelType.RUBBISH_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.RUBBISH_MODEL, t, z, h),
|
||||
# # ModelType.RUBBISH_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
|
||||
# # # 加载烟花模型
|
||||
# # ModelType.FIREWORK_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FIREWORK_MODEL, t, z, h),
|
||||
# # ModelType.FIREWORK_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# # # 加载高速公路抛撒物模型
|
||||
# # ModelType.TRAFFIC_SPILL_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_SPILL_MODEL, t, z, h),
|
||||
# # ModelType.TRAFFIC_SPILL_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 加载高速公路危化品模型
|
||||
# # ModelType.TRAFFIC_CTHC_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_CTHC_MODEL, t, z, h),
|
||||
# # ModelType.TRAFFIC_CTHC_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: model_process(x)
|
||||
# # ),
|
||||
# # # 加载光伏板异常检测模型
|
||||
# # ModelType.TRAFFIC_PANNEL_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.TRAFFIC_PANNEL_MODEL, t, z, h),
|
||||
# # ModelType.TRAFFIC_PANNEL_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# # # 加载自研车牌检测模型
|
||||
# # ModelType.CITY_CARPLATE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: CARPLATEModel(x, y, r, ModelType.CITY_CARPLATE_MODEL, t, z, h),
|
||||
# # ModelType.CITY_CARPLATE_MODEL,
|
||||
# # None,
|
||||
# # lambda x: immulti_process(x)
|
||||
# # ),
|
||||
# # # 加载红外行人检测模型
|
||||
# # ModelType.CITY_INFRAREDPERSON_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_INFRAREDPERSON_MODEL, t, z, h),
|
||||
# # ModelType.CITY_INFRAREDPERSON_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# # # 加载夜间烟火检测模型
|
||||
# # ModelType.CITY_NIGHTFIRESMOKE_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_NIGHTFIRESMOKE_MODEL, t, z, h),
|
||||
# # ModelType.CITY_NIGHTFIRESMOKE_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# # # 加载密集人群计数检测模型
|
||||
# # ModelType.CITY_DENSECROWDCOUNT_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: DENSECROWDCOUNTModel(x, y, r, ModelType.CITY_DENSECROWDCOUNT_MODEL, t, z, h),
|
||||
# # ModelType.CITY_DENSECROWDCOUNT_MODEL,
|
||||
# # None,
|
||||
# # lambda x: cc_process(x)
|
||||
# # ),
|
||||
# # # 加载建筑物下行人检测模型
|
||||
# # ModelType.CITY_UNDERBUILDCOUNT_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: DENSECROWDCOUNTModel(x, y, r, ModelType.CITY_UNDERBUILDCOUNT_MODEL, t, z, h),
|
||||
# # ModelType.CITY_UNDERBUILDCOUNT_MODEL,
|
||||
# # None,
|
||||
# # lambda x: cc_process(x)
|
||||
# # ),
|
||||
# # # 加载火焰面积模型
|
||||
# # ModelType.CITY_FIREAREA_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.CITY_FIREAREA_MODEL, t, z, h),
|
||||
# # ModelType.CITY_FIREAREA_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: forest_process(x)
|
||||
# # ),
|
||||
# # # 加载安防模型
|
||||
# # ModelType.CITY_SECURITY_MODEL.value[1]: (
|
||||
# # lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_SECURITY_MODEL, t, z, h),
|
||||
# # ModelType.CITY_SECURITY_MODEL,
|
||||
# # lambda x, y, z: one_label(x, y, z),
|
||||
# # lambda x: detSeg_demo2(x)
|
||||
# # ),
|
||||
# }
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
enable_file_log: true
|
||||
enable_stderr: true
|
||||
base_path: "./appIOs/logs"
|
||||
log_name: "drgraph_aialg.log"
|
||||
log_fmt: "<green>{time: HH:mm:ss.SSS}</green> [<level>{level}</level>] - <level>{message}</level> @ <cyan>{file}:{line}</cyan> in <blue>{function}</blue>"
|
||||
level: "INFO"
|
||||
rotation: "00:00"
|
||||
retention: "1 days"
|
||||
encoding: "utf8"
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
|
||||
|
||||
"post_process":{ "name":"post_process","conf_thres":0.25,"iou_thres":0.45,"classes":5,"rainbows":[ [0,0,255],[0,255,0],[255,0,0],[255,0,255],[255,255,0],[255,129,0],[255,0,127],[127,255,0],[0,255,127],[0,127,255],[127,0,255],[255,127,255],[255,255,127],[127,255,255],[0,255,255],[255,127,255],[127,255,255], [0,127,0],[0,0,127],[0,255,255]] }
|
||||
|
||||
|
||||
}
|
||||
|
After Width: | Height: | Size: 117 KiB |
|
|
@ -0,0 +1 @@
|
|||
147,241,234,284,0.8705368041992188,3
|
||||
|
After Width: | Height: | Size: 113 KiB |
|
|
@ -0,0 +1,7 @@
|
|||
562,1,593,53,0.8735775947570801,3
|
||||
278,0,308,51,0.876518726348877,3
|
||||
159,0,189,45,0.8780574798583984,3
|
||||
40,0,73,39,0.8795464038848877,3
|
||||
397,41,428,115,0.8820486068725586,3
|
||||
356,36,389,113,0.8837814331054688,3
|
||||
475,39,510,112,0.8847465515136719,3
|
||||
|
After Width: | Height: | Size: 113 KiB |
|
|
@ -0,0 +1,7 @@
|
|||
562,1,593,53,0.8735775947570801,3
|
||||
278,0,308,51,0.876518726348877,3
|
||||
159,0,189,45,0.8780574798583984,3
|
||||
40,0,73,39,0.8795464038848877,3
|
||||
397,41,428,115,0.8820486068725586,3
|
||||
356,36,389,113,0.8837814331054688,3
|
||||
475,39,510,112,0.8847465515136719,3
|
||||
|
After Width: | Height: | Size: 113 KiB |
|
|
@ -0,0 +1,7 @@
|
|||
562,1,593,53,0.8735775947570801,3
|
||||
278,0,308,51,0.876518726348877,3
|
||||
159,0,189,45,0.8780574798583984,3
|
||||
40,0,73,39,0.8795464038848877,3
|
||||
397,41,428,115,0.8820486068725586,3
|
||||
356,36,389,113,0.8837814331054688,3
|
||||
475,39,510,112,0.8847465515136719,3
|
||||
|
After Width: | Height: | Size: 113 KiB |
|
|
@ -0,0 +1,7 @@
|
|||
562,1,593,53,0.8735775947570801,3
|
||||
278,0,308,51,0.876518726348877,3
|
||||
159,0,189,45,0.8780574798583984,3
|
||||
40,0,73,39,0.8795464038848877,3
|
||||
397,41,428,115,0.8820486068725586,3
|
||||
356,36,389,113,0.8837814331054688,3
|
||||
475,39,510,112,0.8847465515136719,3
|
||||
|
After Width: | Height: | Size: 64 KiB |
|
After Width: | Height: | Size: 56 KiB |
|
After Width: | Height: | Size: 56 KiB |
|
After Width: | Height: | Size: 56 KiB |
|
After Width: | Height: | Size: 56 KiB |
|
|
@ -0,0 +1,21 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
|
||||
# 分析状态枚举
|
||||
@unique
class AnalysisStatus(Enum):
    """Lifecycle states of an analysis task."""

    WAITING = "waiting"  # queued, not yet started
    RUNNING = "running"  # analysis in progress
    SUCCESS = "success"  # finished successfully
    TIMEOUT = "timeout"  # exceeded the time limit
    FAILED = "failed"    # finished with an error
|
|
@ -0,0 +1,25 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
|
||||
# 分析类型枚举
|
||||
@unique
class AnalysisType(Enum):
    """Kinds of analysis requests."""

    ONLINE = "1"          # online (live) stream
    OFFLINE = "2"         # offline video
    IMAGE = "3"           # still image
    RECORDING = "9999"    # screen recording
    PULLTOPUSH = "10000"  # pull-and-repush stream
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
'''
|
||||
ocr官方文档: https://ai.baidu.com/ai-doc/OCR/zkibizyhz
|
||||
官方文档: https://ai.baidu.com/ai-doc/VEHICLE/rk3inf9tj
|
||||
参数1: 异常编号
|
||||
参数2: 异常英文描述
|
||||
参数3: 异常中文描述
|
||||
参数4: 0-异常信息统一输出为内部异常
|
||||
1-异常信息可以输出
|
||||
2-输出空的异常信息
|
||||
参数5: 指定异常重试的次数
|
||||
'''
|
||||
|
||||
|
||||
# Baidu SDK error enumeration.
@unique
class BaiduSdkErrorEnum(Enum):
    """Errors returned by the Baidu AI SDK / open API.

    Each member's value is a 5-tuple:
        [0] error code (int, or an ``'SDKxxx'`` string for client-side SDK errors)
        [1] English description as returned by Baidu
        [2] Chinese description shown to users
        [3] exposure policy: 0 - report as internal error,
            1 - message may be shown as-is, 2 - report an empty message
        [4] number of retries allowed for this error
    """

    UNKNOWN_ERROR = (1, "Unknown error", "未知错误", 0, 0)

    SERVICE_TEMPORARILY_UNAVAILABLE = (2, "Service temporarily unavailable", "服务暂不可用,请再次请求", 0, 3)

    UNSUPPORTED_OPENAPI_METHOD = (3, "Unsupported openapi method", "调用的API不存在", 0, 0)

    API_REQUEST_LIMIT_REACHED = (4, "Open api request limit reached", "请求量限制, 请稍后再试!", 1, 5)

    NO_PERMISSION_TO_ACCESS_DATA = (6, "No permission to access data", "无权限访问该用户数据", 1, 0)

    GET_SERVICE_TOKEN_FAILED = (13, "Get service token failed", "获取token失败", 0, 2)

    IAM_CERTIFICATION_FAILED = (14, "IAM Certification failed", "IAM 鉴权失败", 0, 1)

    APP_NOT_EXSITS_OR_CREATE_FAILED = (15, "app not exsits or create failed", "应用不存在或者创建失败", 0, 0)

    API_DAILY_REQUEST_LIMIT_REACHED = (17, "Open api daily request limit reached", "每天请求量超限额!", 1, 2)

    API_QPS_REQUEST_LIMIT_REACHED = (18, "Open api qps request limit reached", "QPS超限额!", 1, 10)

    API_TOTAL_REQUEST_LIMIT_REACHED = (19, "Open api total request limit reached", "请求总量超限额!", 1, 2)

    INVALID_TOKEN = (100, "Invalid parameter", "无效的access_token参数,token拉取失败", 0, 1)

    ACCESS_TOKEN_INVALID_OR_NO_LONGER_VALID = (110, "Access token invalid or no longer valid", "access_token无效,token有效期为30天", 0, 1)

    ACCESS_TOKEN_EXPIRED = (111, "Access token expired", "access token过期,token有效期为30天", 0, 1)

    INTERNAL_ERROR = (282000, "internal error", "服务器内部错误", 0, 1)

    INVALID_PARAM = (216100, "invalid param", "请求中包含非法参数!", 0, 1)

    NOT_ENOUGH_PARAM = (216101, "not enough param", "缺少必须的参数!", 0, 0)

    SERVICE_NOT_SUPPORT = (216102, "service not support", "请求了不支持的服务,请检查调用的url", 0, 0)

    PARAM_TOO_LONG = (216103, "param too long", "请求中某些参数过长!", 1, 0)

    APPID_NOT_EXIST = (216110, "appid not exist", "appid不存在", 0, 0)

    EMPTY_IMAGE = (216200, "empty image", "图片为空!", 1, 0)

    IMAGE_FORMAT_ERROR = (216201, "image format error", "上传的图片格式错误,现阶段我们支持的图片格式为:PNG、JPG、JPEG、BMP", 1, 0)

    IMAGE_SIZE_ERROR = (216202, "image size error", "上传的图片大小错误,分辨率不高于4096*4096", 1, 0)

    IMAGE_SIZE_BASE_ERROR = (216203, "image size error", "上传的图片编码有误", 1, 0)

    RECOGNIZE_ERROR = (216630, "recognize error", "识别错误", 2, 2)

    DETECT_ERROR = (216634, "detect error", "检测错误", 2, 2)

    MISSING_PARAMETERS = (282003, "missing parameters: {参数名}", "请求参数缺失", 0, 0)

    BATCH_ROCESSING_ERROR = (282005, "batch processing error", "处理批量任务时发生部分或全部错误", 0, 5)

    BATCH_TASK_LIMIT_REACHED = (282006, "batch task limit reached", "批量任务处理数量超出限制,请将任务数量减少到10或10以下", 1, 5)

    IMAGE_TRANSCODE_ERROR = (282100, "image transcode error", "图片压缩转码错误", 0, 1)

    IMAGE_SPLIT_LIMIT_REACHED = (282101, "image split limit reached", "长图片切分数量超限!", 1, 1)

    TARGET_DETECT_ERROR = (282102, "target detect error", "未检测到图片中识别目标!", 2, 1)

    TARGET_RECOGNIZE_ERROR = (282103, "target recognize error", "图片目标识别错误!", 2, 1)

    URLS_NOT_EXIT = (282110, "urls not exit", "URL参数不存在,请核对URL后再次提交!", 1, 0)

    URL_FORMAT_ILLEGAL = (282111, "url format illegal", "URL格式非法!", 1, 0)

    # BUGFIX: the Chinese text was a copy-paste of URL_FORMAT_ILLEGAL's
    # "URL格式非法!"; 282112 is a download timeout per the English message.
    URL_DOWNLOAD_TIMEOUT = (282112, "url download timeout", "URL下载超时!", 1, 0)

    URL_RESPONSE_INVALID = (282113, "url response invalid", "URL返回无效参数!", 1, 0)

    URL_SIZE_ERROR = (282114, "url size error", "URL长度超过1024字节或为0!", 1, 0)

    REQUEST_ID_NOT_EXIST = (282808, "request id: xxxxx not exist", "request id xxxxx 不存在", 0, 0)

    RESULT_TYPE_ERROR = (282809, "result type error", "返回结果请求错误(不属于excel或json)", 0, 0)

    IMAGE_RECOGNIZE_ERROR = (282810, "image recognize error", "图像识别错误", 2, 1)

    INVALID_ARGUMENT = (283300, "Invalid argument", "入参格式有误,可检查下图片编码、代码格式是否有误", 1, 0)

    INTERNAL_ERROR_2 = (336000, "Internal error", "服务器内部错误", 0, 0)

    INVALID_ARGUMENT_2 = (336001, "Invalid Argument", "入参格式有误,比如缺少必要参数、图片编码错误等等,可检查下图片编码、代码格式是否有误", 0, 0)

    SDK_IMAGE_SIZE_ERROR = ('SDK100', "image size error", "图片大小超限,最短边至少50px,最长边最大4096px ,建议长宽比3:1以内,图片请求格式支持:PNG、JPG、BMP", 1, 0)

    SDK_IMAGE_LENGTH_ERROR = ('SDK101', "image length error", "图片边长不符合要求,最短边至少50px,最长边最大4096px ,建议长宽比3:1以内", 1, 0)

    SDK_READ_IMAGE_FILE_ERROR = ('SDK102', "read image file error", "读取图片文件错误", 0, 1)

    SDK_CONNECTION_OR_READ_DATA_TIME_OUT = ('SDK108', "connection or read data time out", "连接超时或读取数据超时,请检查本地网络设置、文件读取设置", 0, 3)

    SDK_UNSUPPORTED_IMAGE_FORMAT = ('SDK109', "unsupported image format", "不支持的图片格式,当前支持以下几类图片:PNG、JPG、BMP", 1, 0)
|
||||
|
||||
|
||||
# Lookup table: Baidu error code (value[0]) -> enum member.
# Built by iterating the enum instead of the previous 49-entry
# hand-written literal, so every member is included automatically and
# new members cannot be forgotten when the enum grows.
BAIDUERRORDATA = {member.value[0]: member for member in BaiduSdkErrorEnum}
|
||||
|
||||
@unique
class VehicleEnum(Enum):
    """Vehicle categories: (label key, Chinese name, class index)."""

    CAR = ("car", "小汽车", 0)
    TRICYCLE = ("tricycle", "三轮车", 1)
    MOTORBIKE = ("motorbike", "摩托车", 2)
    CARPLATE = ("carplate", "车牌", 3)
    TRUCK = ("truck", "卡车", 4)
    BUS = ("bus", "巴士", 5)
|
||||
|
||||
|
||||
# Lookup table: label key (value[0]) -> VehicleEnum member.
# Derived from the enum itself rather than a hand-written literal, so it
# stays in sync automatically when members are added.
VehicleEnumVALUE = {member.value[0]: member for member in VehicleEnum}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
|
||||
# Service exception enumeration
@unique
class ExceptionType(Enum):
    """Service-level error catalogue: (error code, Chinese message).

    Codes "SP000"-"SP034" cover stream/model/task errors; "SP99x" codes
    are reserved for resource and system-level failures. SP022/SP023
    messages contain a ``%s`` placeholder for the model name.
    """

    OR_VIDEO_ADDRESS_EXCEPTION = ("SP000", "未拉取到视频流, 请检查拉流地址是否有视频流!")
    ANALYSE_TIMEOUT_EXCEPTION = ("SP001", "AI分析超时!")
    PULLSTREAM_TIMEOUT_EXCEPTION = ("SP002", "原视频拉流超时!")
    READSTREAM_TIMEOUT_EXCEPTION = ("SP003", "原视频读取视频流超时!")
    GET_VIDEO_URL_EXCEPTION = ("SP004", "获取视频播放地址失败!")
    GET_VIDEO_URL_TIMEOUT_EXCEPTION = ("SP005", "获取原视频播放地址超时!")
    PULL_STREAM_URL_EXCEPTION = ("SP006", "拉流地址不能为空!")
    PUSH_STREAM_URL_EXCEPTION = ("SP007", "推流地址不能为空!")
    PUSH_STREAM_TIME_EXCEPTION = ("SP008", "未生成本地视频地址!")
    AI_MODEL_MATCH_EXCEPTION = ("SP009", "未匹配到对应的AI模型!")
    ILLEGAL_PARAMETER_FORMAT = ("SP010", "非法参数格式!")
    PUSH_STREAMING_CHANNEL_IS_OCCUPIED = ("SP011", "推流通道可能被占用, 请稍后再试!")
    VIDEO_RESOLUTION_EXCEPTION = ("SP012", "不支持该分辨率类型的视频,请切换分辨率再试!")
    READ_IAMGE_URL_EXCEPTION = ("SP013", "未能解析图片地址!")
    DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED = ("SP014", "不支持该类型的检测目标!")
    WRITE_STREAM_EXCEPTION = ("SP015", "写流异常!")
    OR_VIDEO_DO_NOT_EXEIST_EXCEPTION = ("SP016", "原视频不存在!")
    MODEL_LOADING_EXCEPTION = ("SP017", "模型加载异常!")
    MODEL_ANALYSE_EXCEPTION = ("SP018", "算法模型分析异常!")
    AI_MODEL_CONFIG_EXCEPTION = ("SP019", "模型配置不能为空!")
    AI_MODEL_GET_CONFIG_EXCEPTION = ("SP020", "获取模型配置异常, 请检查模型配置是否正确!")
    MODEL_GROUP_LIMIT_EXCEPTION = ("SP021", "模型组合个数超过限制!")
    MODEL_NOT_SUPPORT_VIDEO_EXCEPTION = ("SP022", "%s不支持视频识别!")
    MODEL_NOT_SUPPORT_IMAGE_EXCEPTION = ("SP023", "%s不支持图片识别!")
    THE_DETECTION_TARGET_CANNOT_BE_EMPTY = ("SP024", "检测目标不能为空!")
    URL_ADDRESS_ACCESS_FAILED = ("SP025", "URL地址访问失败, 请检测URL地址是否正确!")
    UNIVERSAL_TEXT_RECOGNITION_FAILED = ("SP026", "识别失败!")
    COORDINATE_ACQUISITION_FAILED = ("SP027", "飞行坐标识别异常!")
    PUSH_STREAM_EXCEPTION = ("SP028", "推流异常!")
    MODEL_DUPLICATE_EXCEPTION = ("SP029", "存在重复模型配置!")
    DETECTION_TARGET_NOT_SUPPORT = ("SP031", "存在不支持的检测目标!")
    TASK_EXCUTE_TIMEOUT = ("SP032", "任务执行超时!")
    PUSH_STREAM_URL_IS_NULL = ("SP033", "拉流、推流地址不能为空!")
    PULL_STREAM_NUM_LIMIT_EXCEPTION = ("SP034", "转推流数量超过限制!")
    NOT_REQUESTID_TASK_EXCEPTION = ("SP993", "未查询到该任务,无法停止任务!")
    NO_RESOURCES = ("SP995", "服务器暂无资源可以使用,请稍后30秒后再试!")
    NO_CPU_RESOURCES = ("SP996", "暂无CPU资源可以使用,请稍后再试!")
    SERVICE_COMMON_EXCEPTION = ("SP997", "公共服务异常!")
    NO_GPU_RESOURCES = ("SP998", "暂无GPU资源可以使用,请稍后再试!")
    SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常!")
|
||||
|
|
@ -0,0 +1,762 @@
|
|||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from DrGraph.util.Constant import COLOR
|
||||
|
||||
sys.path.extend(['..', '../AIlib2'])
|
||||
from segutils.segmodel import SegModel
|
||||
from utilsK.queRiver import riverDetSegMixProcess_N
|
||||
from segutils.trafficUtils import tracfficAccidentMixFunction_N
|
||||
from utilsK.drownUtils import mixDrowing_water_postprocess_N
|
||||
from utilsK.noParkingUtils import mixNoParking_road_postprocess_N
|
||||
from utilsK.illParkingUtils import illParking_postprocess
|
||||
from DMPR import DMPRModel
|
||||
from DMPRUtils.jointUtil import dmpr_yolo
|
||||
from yolov5 import yolov5Model
|
||||
from stdc import stdcModel
|
||||
from AI import default_mix
|
||||
from DMPRUtils.jointUtil import dmpr_yolo_stdc
|
||||
|
||||
'''
|
||||
参数说明
|
||||
1. 编号
|
||||
2. 模型编号
|
||||
3. 模型名称
|
||||
4. 选用的模型名称
|
||||
'''
|
||||
|
||||
|
||||
@unique
|
||||
class ModelType2(Enum):
|
||||
WATER_SURFACE_MODEL = ("1", "001", "河道模型", 'river', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["排口", "水生植被", "其它", "漂浮物", "污染排口", "菜地", "违建", "岸坡垃圾"],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6,7] ],###控制哪些检测类别显示、输出
|
||||
'trackPar': {
|
||||
'sort_max_age': 2, # 跟踪链断裂时允许目标消失最大的次数。超过之后,会认为是新的目标。
|
||||
'sort_min_hits': 3, # 每隔目标连续出现的次数,超过这个次数才认为是一个目标。
|
||||
'sort_iou_thresh': 0.2, # 检测最小的置信度。
|
||||
'det_cnt': 10, # 每隔几次做一个跟踪和检测,默认10。
|
||||
'windowsize': 29, # 轨迹平滑长度,一定是奇数,表示每隔几帧做一平滑,默认29。一个目标在多个帧中出现,每一帧中都有一个位置,这些位置的连线交轨迹。
|
||||
'patchCnt': 100, # 每次送入图像的数量,不宜少于100帧。
|
||||
},
|
||||
'postProcess':{'function':riverDetSegMixProcess_N,'pars':{'slopeIndex':[1,3,4,7], 'riverIou':0.1}}, #分割和检测混合处理的函数
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 80,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/river/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{
|
||||
'half':True,
|
||||
'device':'cuda:0' ,
|
||||
'conf_thres':0.25,
|
||||
'iou_thres':0.45,
|
||||
'allowedList':[0,1,2,3],
|
||||
'segRegionCnt':1,
|
||||
'trtFlag_det':False,
|
||||
'trtFlag_seg':False,
|
||||
"score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
},
|
||||
{
|
||||
'weight':'../AIlib2/weights/conf/river/stdc_360X640.pth',
|
||||
'par':{
|
||||
'modelSize':(640,360),
|
||||
'mean':(0.485, 0.456, 0.406),
|
||||
'std' :(0.229, 0.224, 0.225),
|
||||
'numpy':False,
|
||||
'RGB_convert_first':True,
|
||||
'seg_nclass':2},###分割模型预处理参数
|
||||
'model':stdcModel,
|
||||
'name':'stdc'
|
||||
}
|
||||
|
||||
],
|
||||
})
|
||||
|
||||
FOREST_FARM_MODEL = ("2", "002", "森林模型", 'forest2', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["林斑", "病死树", "行人", "火焰", "烟雾"],
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/forest2/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,
|
||||
'device':'cuda:0' ,
|
||||
'conf_thres':0.25,
|
||||
'iou_thres':0.45,
|
||||
'allowedList':[0,1,2,3],
|
||||
'segRegionCnt':1,
|
||||
'trtFlag_det':False,
|
||||
'trtFlag_seg':False,
|
||||
"score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 }
|
||||
},
|
||||
}
|
||||
],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 80,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
TRAFFIC_FARM_MODEL = ("3", "003", "交通模型", 'highWay2', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["行人", "车辆", "纵向裂缝", "横向裂缝", "修补", "网状裂纹", "坑槽", "块状裂纹", "积水", "事故"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':5,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{
|
||||
'function':tracfficAccidentMixFunction_N,
|
||||
'pars':{
|
||||
'RoadArea': 16000,
|
||||
'vehicleArea': 10,
|
||||
'roadVehicleAngle': 15,
|
||||
'speedRoadVehicleAngleMax': 75,
|
||||
'radius': 50 ,
|
||||
'roundness': 1.0,
|
||||
'cls': 9,
|
||||
'vehicleFactor': 0.1,
|
||||
'cls':9,
|
||||
'confThres':0.25,
|
||||
'roadIou':0.6,
|
||||
'vehicleFlag':False,
|
||||
'distanceFlag': False,
|
||||
'modelSize':(640,360),
|
||||
}
|
||||
},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/highWay2/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{
|
||||
'half':True,
|
||||
'device':'cuda:0' ,
|
||||
'conf_thres':0.25,
|
||||
'iou_thres':0.45,
|
||||
'allowedList':[0,1,2,3],
|
||||
'segRegionCnt':1,
|
||||
'trtFlag_det':False,
|
||||
'trtFlag_seg':False,
|
||||
"score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
},
|
||||
{
|
||||
'weight':'../AIlib2/weights/conf/highWay2/stdc_360X640.pth',
|
||||
'par':{
|
||||
'modelSize':(640,360),
|
||||
'mean':(0.485, 0.456, 0.406),
|
||||
'std' :(0.229, 0.224, 0.225),
|
||||
'predResize':True,
|
||||
'numpy':False,
|
||||
'RGB_convert_first':True,
|
||||
'seg_nclass':3},###分割模型预处理参数
|
||||
'model':stdcModel,
|
||||
'name':'stdc'
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,5,6,7,8,9] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.25,
|
||||
"classes": 9,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 20,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'segLineShow': False,
|
||||
'waterLineWidth': 2
|
||||
}
|
||||
})
|
||||
|
||||
EPIDEMIC_PREVENTION_MODEL = ("4", "004", "防疫模型", None, None)
|
||||
|
||||
PLATE_MODEL = ("5", "005", "车牌模型", None, None)
|
||||
|
||||
VEHICLE_MODEL = ("6", "006", "车辆模型", 'vehicle', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["车辆"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/vehicle/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,
|
||||
'device':'cuda:0' ,
|
||||
'conf_thres':0.25,
|
||||
'iou_thres':0.45,
|
||||
'allowedList':[0,1,2,3],
|
||||
'segRegionCnt':1,
|
||||
'trtFlag_det':False,
|
||||
'trtFlag_seg':False,
|
||||
"score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 40,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'segLineShow': False,
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
PEDESTRIAN_MODEL = ("7", "007", "行人模型", 'pedestrian', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["行人"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/pedestrian/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
}
|
||||
],
|
||||
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
SMOGFIRE_MODEL = ("8", "008", "烟火模型", 'smogfire', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["烟雾", "火焰"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/smogfire/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
#'weight':'../AIlib2/weights/conf/%s/yolov5.pt'%(opt['business'] ),
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 40,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
ANGLERSWIMMER_MODEL = ("9", "009", "钓鱼游泳模型", 'AnglerSwimmer', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["钓鱼", "游泳"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/AnglerSwimmer/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 40,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
},
|
||||
})
|
||||
|
||||
COUNTRYROAD_MODEL = ("10", "010", "乡村模型", 'countryRoad', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["违法种植"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/countryRoad/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 40,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
SHIP_MODEL = ("11", "011", "船只模型", 'ship2', lambda device, gpuName: {
|
||||
'obbModelPar': {
|
||||
'labelnames': ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "船只"],
|
||||
'model_size': (608, 608),
|
||||
'K': 100,
|
||||
'conf_thresh': 0.3,
|
||||
'down_ratio': 4,
|
||||
'num_classes': 15,
|
||||
'dataset': 'dota',
|
||||
'heads': {
|
||||
'hm': None,
|
||||
'wh': 10,
|
||||
'reg': 2,
|
||||
'cls_theta': 1
|
||||
},
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'std': (1, 1, 1),
|
||||
'half': False,
|
||||
'test_flag': True,
|
||||
'decoder': None,
|
||||
'weights': '../AIlib2/weights/ship2/obb_608X608_%s_fp16.engine' % gpuName
|
||||
},
|
||||
'trackPar': {
|
||||
'sort_max_age': 2, # 跟踪链断裂时允许目标消失最大的次数。超过之后,会认为是新的目标。
|
||||
'sort_min_hits': 3, # 每隔目标连续出现的次数,超过这个次数才认为是一个目标。
|
||||
'sort_iou_thresh': 0.2, # 检测最小的置信度。
|
||||
'det_cnt': 10, # 每隔几次做一个跟踪和检测,默认10。
|
||||
'windowsize': 29, # 轨迹平滑长度,一定是奇数,表示每隔几帧做一平滑,默认29。一个目标在多个帧中出现,每一帧中都有一个位置,这些位置的连线交轨迹。
|
||||
'patchCnt': 100, # 每次送入图像的数量,不宜少于100帧。
|
||||
},
|
||||
'device': "cuda:%s" % device,
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'drawBox': False,
|
||||
'drawPar': {
|
||||
"rainbows": COLOR,
|
||||
'digitWordFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'wordSize': 40,
|
||||
'fontSize': 1.0,
|
||||
'label_location': 'leftTop'
|
||||
}
|
||||
},
|
||||
'labelnames': ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "船只"]
|
||||
})
|
||||
|
||||
BAIDU_MODEL = ("12", "012", "百度AI图片识别模型", None, None)
|
||||
|
||||
CHANNEL_EMERGENCY_MODEL = ("13", "013", "航道模型", 'channelEmergency', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["人"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':default_mix,'pars':{ }},
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/channelEmergency/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
#'weight':'../AIlib2/weights/conf/%s/yolov5.pt'%(opt['business'] ),
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.45,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 40,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
RIVER2_MODEL = ("15", "015", "河道检测模型", 'river2', lambda device, gpuName: {
|
||||
'device': device,
|
||||
'labelnames': ["漂浮物", "岸坡垃圾", "排口", "违建", "菜地", "水生植物", "河湖人员", "钓鱼人员", "船只",
|
||||
"蓝藻"],
|
||||
'trackPar':{'sort_max_age':2,'sort_min_hits':3,'sort_iou_thresh':0.2,'det_cnt':10,'windowsize':29,'patchCnt':100},
|
||||
'postProcess':{'function':riverDetSegMixProcess_N,'pars':{'slopeIndex':[1,3,4,7], 'riverIou':0.1}}, #分割和检测混合处理的函数
|
||||
'models':
|
||||
[
|
||||
{
|
||||
'weight':"../AIlib2/weights/river2/yolov5_%s_fp16.engine"% gpuName,###检测模型路径
|
||||
'name':'yolov5',
|
||||
'model':yolov5Model,
|
||||
'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } },
|
||||
},
|
||||
{
|
||||
'weight':'../AIlib2/weights/conf/river2/stdc_360X640.pth',
|
||||
'par':{
|
||||
'modelSize':(640,360),'mean':(0.485, 0.456, 0.406),'std' :(0.229, 0.224, 0.225),'numpy':False, 'RGB_convert_first':True,'seg_nclass':2},###分割模型预处理参数
|
||||
'model':stdcModel,
|
||||
'name':'stdc'
|
||||
}
|
||||
],
|
||||
'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6] ],###控制哪些检测类别显示、输出
|
||||
'postFile': {
|
||||
"name": "post_process",
|
||||
"conf_thres": 0.25,
|
||||
"iou_thres": 0.3,
|
||||
"ovlap_thres_crossCategory": 0.65,
|
||||
"classes": 5,
|
||||
"rainbows": COLOR
|
||||
},
|
||||
'txtFontSize': 80,
|
||||
'digitFont': {
|
||||
'line_thickness': 2,
|
||||
'boxLine_thickness': 1,
|
||||
'fontSize': 1.0,
|
||||
'segLineShow': False,
|
||||
'waterLineColor': (0, 255, 255),
|
||||
'waterLineWidth': 3
|
||||
}
|
||||
})
|
||||
|
||||
# City-management model: vehicle / garbage / street-vendor / illegal-parking
# detection, fusing YOLOv5, a DMPR parking-slot net and STDC segmentation.
CITY_MANGEMENT_MODEL = ("16", "016", "城管模型", 'cityMangement2', lambda device, gpuName: {
    'device': device,
    'labelnames': ["车辆", "垃圾", "商贩", "违停"],
    # SORT-style tracker parameters
    'trackPar': {'sort_max_age': 2, 'sort_min_hits': 3, 'sort_iou_thresh': 0.2,
                 'det_cnt': 5, 'windowsize': 29, 'patchCnt': 100},
    'postProcess': {
        'function': dmpr_yolo_stdc,
        'pars': {'carCls': 0, 'illCls': 3, 'scaleRatio': 0.5, 'border': 80}
    },
    'models': [
        {
            # detection model weights (TensorRT engine, fp16)
            'weight': "../AIlib2/weights/cityMangement3/yolov5_%s_fp16.engine" % gpuName,
            'name': 'yolov5',
            'model': yolov5Model,
            'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45,
                    'allowedList': [0, 1, 2, 3], 'segRegionCnt': 1,
                    'trtFlag_det': False, 'trtFlag_seg': False,
                    "score_byClass": {"0": 0.8, "1": 0.5, "2": 0.5, "3": 0.5}}
        },
        {
            # DMPR parking-slot model weights
            'weight': "../AIlib2/weights/cityMangement3/dmpr_%s.engine" % gpuName,
            'par': {'depth_factor': 32, 'NUM_FEATURE_MAP_CHANNEL': 6, 'dmpr_thresh': 0.3,
                    'dmprimg_size': 640, 'name': 'dmpr'},
            'model': DMPRModel,
            'name': 'dmpr'
        },
        {
            # segmentation model weights
            'weight': "../AIlib2/weights/cityMangement3/stdc_360X640_%s_fp16.engine" % gpuName,
            # segmentation pre-processing parameters
            'par': {'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406),
                    'std': (0.229, 0.224, 0.225), 'predResize': True, 'numpy': False,
                    'RGB_convert_first': True, 'seg_nclass': 2},
            'model': stdcModel,
            'name': 'stdc'
        }
    ],
    # which detection classes are rendered / reported
    'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0, 1, 2, 3, 5, 6, 7, 8, 9]],
    'postFile': {
        "name": "post_process",
        "conf_thres": 0.25,
        "iou_thres": 0.45,
        "classes": 5,
        "rainbows": COLOR
    },
    'txtFontSize': 20,
    'digitFont': {
        'line_thickness': 2,
        'boxLine_thickness': 1,
        'fontSize': 1.0,
        'segLineShow': False,
        'waterLineColor': (0, 255, 255),
        'waterLineWidth': 2
    }
})
|
||||
|
||||
# Person-drowning model: head / person / boat detection fused with water
# segmentation.
DROWING_MODEL = ("17", "017", "人员落水模型", 'drowning', lambda device, gpuName: {
    'device': device,
    'labelnames': ["人头", "人", "船只"],
    'trackPar': {'sort_max_age': 2, 'sort_min_hits': 3, 'sort_iou_thresh': 0.2,
                 'det_cnt': 10, 'windowsize': 29, 'patchCnt': 100},
    'postProcess': {'function': mixDrowing_water_postprocess_N, 'pars': {}},
    'models': [
        {
            # detection model weights (TensorRT engine, fp16)
            'weight': "../AIlib2/weights/drowning/yolov5_%s_fp16.engine" % gpuName,
            'name': 'yolov5',
            'model': yolov5Model,
            'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45,
                    'allowedList': [0, 1, 2, 3], 'segRegionCnt': 1,
                    'trtFlag_det': False, 'trtFlag_seg': False,
                    "score_byClass": {"0": 0.25, "1": 0.3, "2": 0.3, "3": 0.3}},
        },
        {
            'weight': '../AIlib2/weights/conf/drowning/stdc_360X640.pth',
            # segmentation pre-processing parameters
            'par': {'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406),
                    'std': (0.229, 0.224, 0.225), 'predResize': True, 'numpy': False,
                    'RGB_convert_first': True, 'seg_nclass': 2},
            'model': stdcModel,
            'name': 'stdc'
        }
    ],
    # which detection classes are rendered / reported
    'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0, 1, 2, 3, 5, 6, 7, 8, 9]],
    'postFile': {
        "name": "post_process",
        "conf_thres": 0.25,
        "iou_thres": 0.25,
        "classes": 9,
        "rainbows": COLOR
    },
    'txtFontSize': 20,
    'digitFont': {
        'line_thickness': 2,
        'boxLine_thickness': 1,
        'fontSize': 1.0,
        'waterLineColor': (0, 255, 255),
        'segLineShow': False,
        'waterLineWidth': 2
    }
})
|
||||
|
||||
# Urban violation (no-parking) model: vehicle detection plus lane/road
# segmentation, with road-geometry post-processing.
NOPARKING_MODEL = (
    "18", "018", "城市违章模型", 'noParking', lambda device, gpuName: {
        'device': device,
        'labelnames': ["车辆", "违停"],
        'trackPar': {'sort_max_age': 2, 'sort_min_hits': 3, 'sort_iou_thresh': 0.2,
                     'det_cnt': 10, 'windowsize': 29, 'patchCnt': 100},
        'postProcess': {
            'function': mixNoParking_road_postprocess_N,
            'pars': {'roundness': 0.3, 'cls': 9, 'laneArea': 10, 'laneAngleCha': 5,
                     'RoadArea': 16000, 'fitOrder': 2, 'modelSize': (640, 360)}
        },
        'models': [
            {
                # detection model weights (TensorRT engine, fp16)
                'weight': "../AIlib2/weights/noParking/yolov5_%s_fp16.engine" % gpuName,
                'name': 'yolov5',
                'model': yolov5Model,
                'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45,
                        'allowedList': [0, 1, 2, 3], 'segRegionCnt': 1,
                        'trtFlag_det': False, 'trtFlag_seg': False,
                        "score_byClass": {"0": 0.25, "1": 0.3, "2": 0.3, "3": 0.3}},
            },
            {
                'weight': '../AIlib2/weights/conf/noParking/stdc_360X640.pth',
                # segmentation pre-processing parameters
                'par': {'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406),
                        'std': (0.229, 0.224, 0.225), 'predResize': True, 'numpy': False,
                        'RGB_convert_first': True, 'seg_nclass': 4},
                'model': stdcModel,
                'name': 'stdc'
            }
        ],
        # which detection classes are rendered / reported
        'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0, 1, 2, 3, 5, 6, 7, 8, 9]],
        'postFile': {
            "name": "post_process",
            "conf_thres": 0.25,
            "iou_thres": 0.25,
            "classes": 9,
            "rainbows": COLOR
        },
        'txtFontSize': 20,
        'digitFont': {
            'line_thickness': 2,
            'boxLine_thickness': 1,
            'fontSize': 1.0,
            'waterLineColor': (0, 255, 255),
            'segLineShow': False,
            'waterLineWidth': 2
        }
    }
)
|
||||
|
||||
# City-road model: guardrail / traffic-sign / construction detection only
# (no segmentation stage).
CITYROAD_MODEL = ("20", "020", "城市公路模型", 'cityRoad', lambda device, gpuName: {
    'device': device,
    # NOTE(review): "施工" appears twice in the original label list — kept as-is.
    'labelnames': ["护栏", "交通标志", "非交通标志", "施工", "施工"],
    'trackPar': {'sort_max_age': 2, 'sort_min_hits': 3, 'sort_iou_thresh': 0.2,
                 'det_cnt': 10, 'windowsize': 29, 'patchCnt': 100},
    'postProcess': {'function': default_mix, 'pars': {}},
    'models': [
        {
            # detection model weights (TensorRT engine, fp16)
            'weight': "../AIlib2/weights/cityRoad/yolov5_%s_fp16.engine" % gpuName,
            'name': 'yolov5',
            'model': yolov5Model,
            'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.8, 'iou_thres': 0.45,
                    'allowedList': [0, 1, 2, 3], 'segRegionCnt': 1,
                    'trtFlag_det': False, 'trtFlag_seg': False,
                    "score_byClass": {"0": 0.25, "1": 0.3, "2": 0.3, "3": 0.3}},
        }
    ],
    # which detection classes are rendered / reported
    'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0, 1, 2, 3, 4, 5, 6]],
    'postFile': {
        "name": "post_process",
        "conf_thres": 0.8,
        "iou_thres": 0.45,
        "classes": 5,
        "rainbows": COLOR
    },
    'txtFontSize': 40,
    'digitFont': {
        'line_thickness': 2,
        'boxLine_thickness': 1,
        'fontSize': 1.0,
        'segLineShow': False,
        'waterLineColor': (0, 255, 255),
        'waterLineWidth': 3
    }
})
|
||||
|
||||
# Pothole detection model. Currently integrated into other models rather than
# used standalone (per the original note).
POTHOLE_MODEL = ("23", "023", "坑槽检测模型", 'pothole', lambda device, gpuName: {
    'device': device,
    'labelnames': ["坑槽"],
    'trackPar': {'sort_max_age': 2, 'sort_min_hits': 3, 'sort_iou_thresh': 0.2,
                 'det_cnt': 3, 'windowsize': 29, 'patchCnt': 100},
    'postProcess': {'function': default_mix, 'pars': {}},
    'models': [
        {
            # detection model weights (TensorRT engine, fp16)
            'weight': "../AIlib2/weights/pothole/yolov5_%s_fp16.engine" % gpuName,
            'name': 'yolov5',
            'model': yolov5Model,
            'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45,
                    'allowedList': [0, 1, 2, 3], 'segRegionCnt': 1,
                    'trtFlag_det': False, 'trtFlag_seg': False,
                    "score_byClass": {"0": 0.25, "1": 0.3, "2": 0.3, "3": 0.3}},
        }
    ],
    # which detection classes are rendered / reported
    'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0]],
    'postFile': {
        "name": "post_process",
        "conf_thres": 0.25,
        "iou_thres": 0.45,
        "classes": 5,
        "rainbows": COLOR
    },
    'txtFontSize': 40,
    'digitFont': {
        'line_thickness': 2,
        'boxLine_thickness': 1,
        'fontSize': 1.0,
        'segLineShow': False,
        'waterLineColor': (0, 255, 255),
        'waterLineWidth': 3
    },
})
|
||||
|
||||
|
||||
@staticmethod
def checkCode(code):
    """Return True if *code* equals value[1] of any ModelType2 member."""
    return any(model.value[1] == code for model in ModelType2)
|
||||
|
||||
|
||||
'''
|
||||
参数1: 检测目标名称
|
||||
参数2: 检测目标
|
||||
参数3: 初始化百度检测客户端
|
||||
'''
|
||||
|
||||
|
||||
@unique
class BaiduModelTarget2(Enum):
    """Baidu cloud analysis targets: (display name, index, request dispatcher)."""

    # vehicle detection, dispatched through client0
    VEHICLE_DETECTION = ("车辆检测", 0,
                         lambda client0, client1, url, request_id: client0.vehicleDetectUrl(url, request_id))

    # human-body detection + attribute recognition, dispatched through client1
    HUMAN_DETECTION = ("人体检测与属性识别", 1,
                       lambda client0, client1, url, request_id: client1.bodyAttr(url, request_id))

    # crowd counting, dispatched through client1
    PEOPLE_COUNTING = ("人流量统计", 2,
                       lambda client0, client1, url, request_id: client1.bodyNum(url, request_id))
|
||||
|
||||
|
||||
# Index → enum member lookup for the Baidu analysis targets above.
BAIDU_MODEL_TARGET_CONFIG2 = {target.value[1]: target for target in BaiduModelTarget2}

# Epidemic-prevention artefact codes.
EPIDEMIC_PREVENTION_CONFIG = {1: "行程码", 2: "健康码"}
|
||||
|
||||
|
||||
# How a model analyses incoming frames.
@unique
class ModelMethodTypeEnum2(Enum):
    """Analysis strategies: plain per-frame recognition vs. tracking."""

    NORMAL = 1  # mode 1: plain recognition
    TRACE = 2   # mode 2: recognition with tracking
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
|
||||
# Screen-recording lifecycle states as (code, description) pairs.
@unique
class RecordingStatus(Enum):
    """Status codes reported by the recording service."""

    RECORDING_WAITING = ("5", "待录制")
    RECORDING_RETRYING = ("10", "重试中")
    RECORDING_RUNNING = ("15", "录制中")
    RECORDING_SUCCESS = ("20", "录制完成")
    RECORDING_TIMEOUT = ("25", "录制超时")
    RECORDING_FAILED = ("30", "录制失败")
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
|
||||
@unique
class PushStreamStatus(Enum):
    """Push-stream lifecycle states as (code, description) pairs."""

    WAITING = (5, "待推流")
    RETRYING = (10, "重试中")
    RUNNING = (15, "推流中")
    STOPPING = (20, "停止中")
    SUCCESS = (25, "完成")
    TIMEOUT = (30, "超时")
    FAILED = (35, "失败")
|
||||
|
||||
|
||||
@unique
class ExecuteStatus(Enum):
    """Task-execution lifecycle states as (code, description) pairs."""

    WAITING = (5, "待执行")
    RUNNING = (10, "执行中")
    STOPPING = (15, "停止中")
    SUCCESS = (20, "执行完成")
    TIMEOUT = (25, "超时")
    FAILED = (30, "失败")
|
||||
|
|
@ -0,0 +1,443 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Text encoding used throughout the service.
UTF_8 = "utf-8"

# File-open mode for reading.
R = 'r'

# Filename fragments marking original ("or") and AI-annotated ("ai") outputs.
ON_OR = "_on_or_"
ON_AI = "_on_ai_"
MP4 = ".mp4"

# Initial progress value.
init_progess = "0.0000"
# Correctly spelled alias for the misspelled `init_progess`; prefer this name
# in new code (the old name is kept for backward compatibility).
init_progress = init_progess

# Completed (100%) progress value.
success_progess = "1.0000"
# Correctly spelled alias, kept alongside the legacy name for the same reason.
success_progress = success_progess

# Pulled frames wider than this many pixels are downscaled by half;
# narrower frames are left unchanged.
width = 1400
|
||||
|
||||
# BGR palette (20 entries) used to colour detection classes, cycled by index.
COLOR = (
    [255, 0, 0], [211, 0, 148], [0, 127, 0], [0, 69, 255], [0, 255, 0],
    [255, 0, 255], [0, 0, 127], [127, 0, 255], [255, 129, 0], [139, 139, 0],
    [255, 255, 0], [127, 255, 0], [0, 127, 255], [0, 255, 127], [255, 127, 255],
    [8, 101, 139], [171, 130, 255], [139, 112, 74], [205, 205, 180], [0, 0, 255],
)
|
||||
|
||||
# Task-type identifiers used to route incoming requests.
ONLINE = "online"
OFFLINE = "offline"
PHOTO = "photo"
RECORDING = "recording"
|
||||
|
||||
# Validation schema (Cerberus-style rules) for the "start" command of a live
# (online) analysis request.
ONLINE_START_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["start"]},
    "pull_url": {'type': 'string', 'required': True, 'empty': False, 'maxlength': 255},
    "push_url": {'type': 'string', 'required': True, 'empty': False, 'maxlength': 255},
    "logo_url": {'type': 'string', 'required': False, 'nullable': True, 'maxlength': 255},
    "models": {
        'type': 'list', 'required': True, 'nullable': False,
        'minlength': 1, 'maxlength': 3,
        'schema': {
            'type': 'dict', 'required': True,
            'schema': {
                "code": {'type': 'string', 'required': True, 'empty': False,
                         'dependencies': "categories", 'regex': r'^[a-zA-Z0-9]{1,255}$'},
                "is_video": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "code", 'allowed': ["0", "1"]},
                "is_image": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "code", 'allowed': ["0", "1"]},
                "categories": {
                    'type': 'list', 'required': True, 'dependencies': "code",
                    'schema': {
                        'type': 'dict', 'required': True,
                        'schema': {
                            "id": {'type': 'string', 'required': True, 'empty': False,
                                   'regex': r'^[a-zA-Z0-9]{0,255}$'},
                            "config": {'type': 'dict', 'required': False,
                                       'dependencies': "id"},
                        }
                    }
                }
            }
        }
    }
}
|
||||
|
||||
# Validation schema for the "stop" command of a live analysis request.
ONLINE_STOP_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["stop"]},
}
|
||||
|
||||
# Validation schema for the "start" command of a recorded-video (offline)
# analysis request. Same shape as the online variant except that the "models"
# list rule carries no explicit 'nullable' key.
OFFLINE_START_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["start"]},
    "push_url": {'type': 'string', 'required': True, 'empty': False, 'maxlength': 255},
    "pull_url": {'type': 'string', 'required': True, 'empty': False, 'maxlength': 255},
    "logo_url": {'type': 'string', 'required': False, 'nullable': True, 'maxlength': 255},
    "models": {
        'type': 'list', 'required': True, 'maxlength': 3, 'minlength': 1,
        'schema': {
            'type': 'dict', 'required': True,
            'schema': {
                "code": {'type': 'string', 'required': True, 'empty': False,
                         'dependencies': "categories", 'regex': r'^[a-zA-Z0-9]{1,255}$'},
                "is_video": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "code", 'allowed': ["0", "1"]},
                "is_image": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "code", 'allowed': ["0", "1"]},
                "categories": {
                    'type': 'list', 'required': True, 'dependencies': "code",
                    'schema': {
                        'type': 'dict', 'required': True,
                        'schema': {
                            "id": {'type': 'string', 'required': True, 'empty': False,
                                   'regex': r'^[a-zA-Z0-9]{0,255}$'},
                            "config": {'type': 'dict', 'required': False,
                                       'dependencies': "id"},
                        }
                    }
                }
            }
        }
    }
}
|
||||
|
||||
# Validation schema for the "stop" command of an offline analysis request.
OFFLINE_STOP_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["stop"]},
}
|
||||
|
||||
# Validation schema for a still-image analysis request.
IMAGE_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["start"]},
    "logo_url": {'type': 'string', 'required': False, 'nullable': True, 'maxlength': 255},
    "image_urls": {
        'type': 'list', 'required': True, 'minlength': 1,
        'schema': {'type': 'string', 'required': True, 'empty': False, 'maxlength': 5000},
    },
    "models": {
        'type': 'list', 'required': True,
        'schema': {
            'type': 'dict', 'required': True,
            'schema': {
                "code": {'type': 'string', 'required': True, 'empty': False,
                         'dependencies': "categories", 'regex': r'^[a-zA-Z0-9]{1,255}$'},
                "is_video": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "code", 'allowed': ["0", "1"]},
                "is_image": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "code", 'allowed': ["0", "1"]},
                "categories": {
                    'type': 'list', 'required': True, 'dependencies': "code",
                    'schema': {
                        'type': 'dict', 'required': True,
                        'schema': {
                            "id": {'type': 'string', 'required': True, 'empty': False,
                                   'regex': r'^[a-zA-Z0-9]{0,255}$'},
                            "config": {'type': 'dict', 'required': False,
                                       'dependencies': "id"},
                        }
                    }
                }
            }
        }
    }
}
|
||||
|
||||
# Validation schema for the "start" command of a recording request.
RECORDING_START_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["start"]},
    "pull_url": {'type': 'string', 'required': True, 'empty': False, 'maxlength': 255},
    "push_url": {'type': 'string', 'required': False, 'empty': True, 'maxlength': 255},
    "logo_url": {'type': 'string', 'required': False, 'nullable': True, 'maxlength': 255},
}
|
||||
|
||||
# Validation schema for the "stop" command of a recording request.
RECORDING_STOP_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["stop"]},
}
|
||||
|
||||
# Validation schema for the "start" command of a pull-to-push relay request.
PULL2PUSH_START_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["start"]},
    "video_urls": {
        'type': 'list', 'required': True, 'nullable': False,
        'schema': {
            'type': 'dict', 'required': True,
            'schema': {
                "id": {'type': 'string', 'required': True, 'empty': False,
                       'dependencies': "pull_url", 'regex': r'^[a-zA-Z0-9]{1,255}$'},
                "pull_url": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "push_url",
                             'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$'},
                "push_url": {'type': 'string', 'required': True, 'empty': False,
                             'dependencies': "id",
                             'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$'},
            }
        }
    }
}
|
||||
# Validation schema for stopping (or re-starting) pull-to-push relay streams.
PULL2PUSH_STOP_SCHEMA = {
    "request_id": {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,36}$'},
    "command": {'type': 'string', 'required': True, 'allowed': ["start", "stop"]},
    "video_ids": {
        'type': 'list', 'required': False, 'nullable': True,
        'schema': {'type': 'string', 'required': True, 'empty': False,
                   'regex': r'^[a-zA-Z0-9]{1,255}$'},
    },
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
from os import makedirs
|
||||
from os.path import join, exists
|
||||
from loguru import logger
|
||||
from util.RWUtils import getConfigs
|
||||
|
||||
def S(*args, **kwargs):
    """Render all positional and keyword arguments as one space-separated string.

    Keyword arguments are rendered as ``key=value``.
    """
    pieces = [str(arg) for arg in args]
    pieces.extend(f"{key}={value}" for key, value in kwargs.items())
    return " ".join(pieces)
|
||||
|
||||
# Initialise loguru from the environment-specific YAML config.
def init_log(base_dir, env):
    """Configure loguru sinks from ``appIOs/conf/logger/<env>_logger.yml``.

    Creates the log directory if missing, removes all existing handlers, then
    adds a rotating file sink and/or a stderr sink per the config flags.
    """
    cfg = getConfigs(join(base_dir, './appIOs/conf/logger/%s_logger.yml' % env))
    # ensure the log directory exists
    log_dir = join(base_dir, cfg.get("base_path"))
    if not exists(log_dir):
        makedirs(log_dir)
    # drop every pre-existing handler (including loguru's default one)
    logger.remove(handler_id=None)
    # optional rotating file sink
    if bool(cfg.get("enable_file_log")):
        logger.add(join(log_dir, cfg.get("log_name")),
                   rotation=cfg.get("rotation"),
                   retention=cfg.get("retention"),
                   format=cfg.get("log_fmt"),
                   level=cfg.get("level"),
                   enqueue=True,
                   encoding=cfg.get("encoding"))
    # optional console (stderr) sink
    if bool(cfg.get("enable_stderr")):
        logger.add(sys.stderr,
                   format=cfg.get("log_fmt"),
                   level=cfg.get("level"),
                   enqueue=True)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
from os import makedirs
|
||||
from os.path import join, exists
|
||||
from loguru import logger
|
||||
from json import loads
|
||||
from ruamel.yaml import safe_load
|
||||
|
||||
def getConfigs(path, read_type='yml'):
    """Parse the config file at *path* and return its contents.

    :param path: path to the configuration file.
    :param read_type: ``'json'`` or ``'yml'`` (default).
    :raises ValueError: if *read_type* is not one of the supported formats.
    """
    with open(path, 'r', encoding="utf-8") as f:
        if read_type == 'json':
            return loads(f.read())
        if read_type == 'yml':
            return safe_load(f)
        # unsupported format: raise a specific error instead of a bare Exception
        # (ValueError still satisfies callers that catch Exception)
        raise ValueError('路径: %s未获取配置信息' % path)
|
||||
|
||||
|
||||
def readFile(file, ty="rb"):
    """Read *file* in mode *ty* (binary by default) and return its contents."""
    with open(file, ty) as fh:
        return fh.read()
|
||||
|
||||
# Initialise loguru from the environment-specific YAML config (DrGraph layout).
def init_log(base_dir, env):
    """Configure loguru sinks from ``DrGraph/appIOs/conf/logger/<env>_logger.yml``.

    Creates the log directory if missing, removes all existing handlers, then
    adds a rotating file sink and/or a stderr sink per the config flags.
    """
    cfg = getConfigs(join(base_dir, './DrGraph/appIOs/conf/logger/%s_logger.yml' % env))
    # ensure the log directory exists
    log_dir = join(base_dir, cfg.get("base_path"))
    if not exists(log_dir):
        makedirs(log_dir)
    # drop every pre-existing handler (including loguru's default one)
    logger.remove(handler_id=None)
    # optional rotating file sink
    if bool(cfg.get("enable_file_log")):
        logger.add(join(log_dir, cfg.get("log_name")),
                   rotation=cfg.get("rotation"),
                   retention=cfg.get("retention"),
                   format=cfg.get("log_fmt"),
                   level=cfg.get("level"),
                   enqueue=True,
                   encoding=cfg.get("encoding"))
    # optional console (stderr) sink
    if bool(cfg.get("enable_stderr")):
        logger.add(sys.stderr,
                   format=cfg.get("log_fmt"),
                   level=cfg.get("level"),
                   enqueue=True)
|
||||
|
|
@ -0,0 +1,476 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import unicodedata
|
||||
from loguru import logger
|
||||
# TrueType font used to render Chinese label text.
FONT_PATH = "./DrGraph/appIOs/conf/platech.ttf"

# Shared 20 px font instance, loaded once at import time.
zhFont = ImageFont.truetype(FONT_PATH, 20, encoding="utf-8")
|
||||
|
||||
def get_label_array(color=None, label=None, font=None, fontSize=40, unify=False):
    """Render *label* in white onto a solid *color* background, as an array.

    When *unify* is True the canvas is sized from the reference glyph "标" so
    every rendered label gets identical dimensions.
    NOTE(review): ``font.getbbox`` returns (left, top, right, bottom); the last
    two values are used here as width/height — confirm this matches the intent.
    """
    ref_text = "标" if unify else label
    _, _, canvas_w, canvas_h = ref_text and font.getbbox(ref_text) or font.getbbox(ref_text)
    canvas = Image.fromarray(np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8))
    painter = ImageDraw.Draw(canvas)
    painter.rectangle((0, 0, canvas_w, canvas_h), fill=tuple(color))
    painter.text((0, -1), label, fill=(255, 255, 255), font=font)
    rendered = np.asarray(canvas)
    # scale factor relative to the requested font size; this is the inverse of
    # the commented-out original formula and is preserved as-is
    ratio = canvas_h / fontSize
    return cv2.resize(rendered, (0, 0), fx=ratio, fy=ratio)
|
||||
|
||||
def get_label_arrays(labelNames, colors, fontSize=40, fontPath="platech.ttf"):
    """Pre-render one label image per class name, cycling through 20 colours."""
    font = ImageFont.truetype(fontPath, fontSize, encoding='utf-8')
    return [get_label_array(colors[idx % 20], class_name, font, fontSize)
            for idx, class_name in enumerate(labelNames)]
|
||||
|
||||
def get_label_array_dict(colors, fontSize=40, fontPath="platech.ttf"):
    """Pre-render a glyph image for every CJK ideograph, ASCII letter and digit.

    Returns a dict mapping each character to its rendered array; all entries
    share the same size because get_label_array is called with unify=True.
    """
    font = ImageFont.truetype(fontPath, fontSize, encoding='utf-8')
    chars = [chr(cp) for cp in range(0x4E00, 0x9FFF + 1)
             if unicodedata.category(chr(cp)) == 'Lo']      # CJK ideographs
    chars.extend(chr(cp) for cp in range(0x0041, 0x005B))   # A-Z
    chars.extend(chr(cp) for cp in range(0x0061, 0x007B))   # a-z
    chars.extend(chr(cp) for cp in range(0x0030, 0x003A))   # 0-9
    return {ch: get_label_array(colors[2], ch, font, fontSize, unify=True)
            for ch in chars}
|
||||
|
||||
def get_label_left(x0, y1, label_array, img):
    """Clamp the label-image rectangle to the image bounds.

    ``(x0, y1)`` is the detection box's top-left corner; the label sits
    directly above it. Returns the clamped ``(x0, y0, x1, y1)`` span.
    """
    img_h, img_w = img.shape[0:2]
    lab_h, lab_w = label_array.shape[0:2]
    x1 = x0 + lab_w
    y0 = y1 - lab_h
    if y0 < 0:          # label would stick out above the image
        y0, y1 = 0, lab_h
    if y1 > img_h:      # label would stick out below the image
        y1 = img_h
        y0 = y1 - lab_h
    if x0 < 0:          # label would stick out on the left
        x0, x1 = 0, lab_w
    if x1 > img_w:      # label would stick out on the right
        x1 = img_w
        x0 = x1 - lab_w
    return x0, y0, x1, y1
|
||||
|
||||
def get_label_right(x1, y0, label_array):
    """Compute the label span when anchored at the box's top-right corner.

    ``(x1, y0)`` is the box's top-right corner; the label extends left and up,
    clamped against the top-left image corner. Returns ``(x0, y1, x1, y0)``
    (note the original's unusual ordering is preserved).
    """
    lab_h, lab_w = label_array.shape[0:2]
    x0 = x1 - lab_w
    y1 = y0 - lab_h
    if y0 < 0 or y1 < 0:    # clipped by the top edge
        y1 = 0
        y0 = y1 + lab_h
    if x0 < 0 or x1 < 0:    # clipped by the left edge
        x0 = 0
        x1 = x0 + lab_w
    return x0, y1, x1, y0
|
||||
|
||||
def xywh2xyxy(box):
    """Convert a flat ``(xc, yc, w, h)`` box to four clockwise corner points.

    A box whose first element is already a point/sequence is returned as-is.
    """
    if isinstance(box[0], (list, tuple, np.ndarray)):
        return box
    xc, yc, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    half_w, half_h = int(w / 2), int(h / 2)
    left, top = xc - half_w, yc - half_h
    right, bottom = xc + half_w, yc + half_h
    return [(left, top), (right, top), (right, bottom), (left, bottom)]
|
||||
|
||||
def xywh2xyxy2(param):
    """Split a detection record into ``(corner points, score, class id)``.

    Flat input ``[x1, y1, x2, y2, score, cls]`` yields four corner tuples;
    nested input ``[box, score, cls]`` yields the first four box values as an
    ``int32`` array.
    """
    if isinstance(param[0], (list, tuple, np.ndarray)):
        return np.asarray(param[0][0:4], np.int32), float(param[1]), int(param[2])
    x1, y1, x2, y2 = int(param[0]), int(param[1]), int(param[2]), int(param[3])
    corners = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
    return corners, float(param[4]), int(param[5])
|
||||
|
||||
def xy2xyxy(box):
    """Expand a flat ``(x1, y1, x2, y2)`` box into four clockwise corners.

    A box whose first element is already a point/sequence is returned as-is.
    """
    if isinstance(box[0], (list, tuple, np.ndarray)):
        return box
    left, top, right, bottom = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    return [(left, top), (right, top), (right, bottom), (left, bottom)]
|
||||
|
||||
def draw_painting_joint(box, img, label_array, score=0.5, color=None, config=None, isNew=False, border=None):
    """Draw a detection box, its pre-rendered class label and the score on *img*.

    ``config`` is a 5-tuple: (line thickness, score box width, score box
    height, font scale, font thickness). Returns ``(img, corner_points)``;
    point-array (polygon) detections are delegated to ``draw_name_points``.
    """
    if border is not None:
        border = np.array(border, np.int32)
        # border regions may recolour the class label (helper defined elsewhere)
        color, label_array = draw_name_border(box, color, label_array, border)

    lab_h, lab_w = label_array.shape[0:2]
    line_tl = config[0]
    # polygon detections take a separate drawing path
    if isinstance(box[-1], np.ndarray):
        return draw_name_points(img, box, color)

    score_text = ' %.2f' % score
    box = xywh2xyxy(box)
    # clamp the label image span, anchored above the box's top-left corner
    x0, y0, x1, y1 = get_label_left(box[0][0], box[0][1], label_array, img)
    # blit the label image, then outline the detection polygon
    img[y0:y1, x0:x1, :] = label_array
    poly = np.asarray(box, np.int32)
    cv2.polylines(img, [poly], True, color, line_tl)
    label_span = [(x0, y0), (x1, y1)]
    # filled background to the right of the label, then the score text on top
    score_size = (config[1], config[2])
    bg_tl = (label_span[1][0], label_span[0][1])
    bg_br = (label_span[1][0] + score_size[0], label_span[1][1])
    cv2.rectangle(img, bg_tl, bg_br, color, -1, cv2.LINE_AA)
    text_org = label_span[1][0], label_span[1][1] - (lab_h - score_size[1]) // 2
    # isNew selects black text; otherwise the original near-white [225,255,255]
    text_color = [0, 0, 0] if isNew else [225, 255, 255]
    cv2.putText(img, score_text, text_org, 0, config[3], text_color,
                thickness=config[4], lineType=cv2.LINE_AA)
    return img, box
|
||||
|
||||
# 动态标签
|
||||
def draw_name_joint(box, img, label_array_dict, score=0.5, color=None, config=None, name=""):
    """Draw a box whose label is assembled character-by-character from a cache.

    Concatenates per-character glyph arrays from *label_array_dict* for the
    characters of *name* (missing characters are skipped); with no drawable
    character only the box outline and score are drawn.
    """
    label_array = None
    for ch in name:
        glyph = label_array_dict.get(ch)
        if glyph is None:
            continue
        if label_array is None:
            label_array = glyph
        else:
            label_array = np.concatenate((label_array, glyph), axis=1)

    lab_h, lab_w = (0, 0) if label_array is None else label_array.shape[0:2]
    img_h, img_w = img.shape[0:2]
    box = xywh2xyxy(box)
    # clamp the label span to the image, anchored above the box's top-left corner
    x0, y1 = box[0][0], box[0][1]
    x1, y0 = x0 + lab_w, y1 - lab_h
    if y0 < 0:          # clipped by the top edge
        y0, y1 = 0, lab_h
    if y1 > img_h:      # clipped by the bottom edge
        y1 = img_h
        y0 = y1 - lab_h
    if x0 < 0:          # clipped by the left edge
        x0, x1 = 0, lab_w
    if x1 > img_w:      # clipped by the right edge
        x1 = img_w
        x0 = x1 - lab_w
    cv2.polylines(img, [np.asarray(box, np.int32)], True, color, config[0])
    if label_array is not None:
        img[y0:y1, x0:x1, :] = label_array
    label_span = [(x0, y0), (x1, y1)]
    # filled background to the right of the label, then the score text on top
    score_text = ' %.2f' % score
    score_size = (config[1], config[2])
    bg_tl = (label_span[1][0], label_span[0][1])
    bg_br = (label_span[1][0] + score_size[0], label_span[1][1])
    cv2.rectangle(img, bg_tl, bg_br, color, -1, cv2.LINE_AA)
    text_org = label_span[1][0], label_span[1][1] - (lab_h - score_size[1]) // 2
    cv2.putText(img, score_text, text_org, 0, config[3], [225, 255, 255],
                thickness=config[4], lineType=cv2.LINE_AA)
    return img, box
|
||||
|
||||
def draw_name_ocr(box, img, color, line_thickness=2, outfontsize=40):
    """Render ``box[0]`` (the recognised text) as a label image and draw it
    together with the box geometry ``box[1]`` via plot_one_box_auto."""
    ocr_font = ImageFont.truetype(FONT_PATH, outfontsize, encoding='utf-8')
    text_img = get_label_array(color, box[0], ocr_font, outfontsize)
    return plot_one_box_auto(box[1], img, color, line_thickness, text_img)
|
||||
def filterBox(det0, det1, pix_dis):
    """For each box in ``det0``, count how many boxes in ``det1`` share its
    class and have a (floor-divided) center within ``pix_dis`` pixels, and
    append that count as a new last column.

    Both inputs hold 4-corner polygon boxes: columns 0..7 are the corner
    coordinates, column 8 the score, column 9 the class id.  ``det1`` may
    carry extra columns; only its first 11 are used.

    Returns ``det0`` widened to 12 columns (count in the last one); an empty
    ``det0`` is returned unchanged.
    """
    # Promote single detections to row matrices.
    if len(det0.shape) == 1:
        det0 = det0[np.newaxis, ...]
    if len(det1.shape) == 1:
        det1 = det1[np.newaxis, ...]
    # Only the first 11 columns of det1 are comparable with det0.
    det1 = det1[..., 0:11].copy()

    if not det0.size:
        return det0

    # Append a zeroed "match count" column to det0.
    counted = np.concatenate([det0, np.zeros([len(det0), 1])], axis=1)
    out = counted.copy()

    if not det1.size:
        return counted

    # Build the (m1, m2, 23) pairwise matrix: det1's 11 columns first,
    # then det0's 12 columns.
    lhs = counted[:, np.newaxis, :].repeat(det1.shape[0], 1)
    rhs = det1[np.newaxis, ...].repeat(counted.shape[0], 0)
    pairs = np.concatenate((rhs, lhs), axis=2)

    # Floor-divided box centers: det1 corners at cols 0,1 / 4,5,
    # det0 corners at cols 11,12 / 15,16 (offset by det1's 11 columns).
    cx_b, cy_b = (pairs[..., 0] + pairs[..., 4]) // 2, (pairs[..., 1] + pairs[..., 5]) // 2
    cx_a, cy_a = (pairs[..., 11] + pairs[..., 15]) // 2, (pairs[..., 12] + pairs[..., 16]) // 2
    sq_dist = (cx_b - cx_a) ** 2 + (cy_b - cy_a) ** 2

    # Match = same class (col 9 of det1 vs col 20 = 11+9 of det0) and
    # centers within pix_dis.
    match = (pairs[..., 9] == pairs[..., 20]) & (sq_dist <= pix_dis ** 2)

    out[..., -1] = np.sum(match, axis=1)
    return out
|
||||
|
||||
def plot_one_box_auto(box, img, color=None, line_thickness=2, label_array=None):
    """Draw a polygon box on ``img`` and paste the pre-rendered label image
    above its first corner, clamping the label to stay inside the image.

    Args:
        box: polygon corner points; converted through ``xy2xyxy``.
        img: BGR image (ndarray), modified in place.
        color: BGR color for the polygon outline.
        line_thickness: polygon line width; when falsy, an image-size based
            default is used.
        label_array: small BGR image of the rendered label, or None to draw
            only the polygon.  (Fix: the original dereferenced
            ``label_array.shape`` unconditionally, so the documented default
            of None crashed; None is now handled like in the sibling helper.)

    Returns:
        (img, box): the annotated image and the converted polygon.
    """
    # Label patch size (0 when there is no label to paste).
    if label_array is None:
        lh, lw = 0, 0
    else:
        lh, lw = label_array.shape[0:2]
    # Image height and width.
    imh, imw = img.shape[0:2]
    box = xy2xyxy(box)
    # Top-left corner of the box; the label sits directly above it.
    x0, y1 = box[0][0], box[0][1]
    x1, y0 = x0 + lw, y1 - lh
    # Clamp the label rectangle so it never leaves the image:
    # shifted down if it runs off the top ...
    if y0 < 0:
        y0 = 0
        y1 = y0 + lh
    # ... up if it runs off the bottom ...
    if y1 > imh:
        y1 = imh
        y0 = y1 - lh
    # ... right if it runs off the left ...
    if x0 < 0:
        x0 = 0
        x1 = x0 + lw
    # ... and left if it runs off the right.
    if x1 > imw:
        x1 = imw
        x0 = x1 - lw
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    box1 = np.asarray(box, np.int32)
    cv2.polylines(img, [box1], True, color, tl)
    if label_array is not None:
        img[y0:y1, x0:x1, :] = label_array
    return img, box
|
||||
|
||||
def draw_name_crowd(dets, img, color, outfontsize=20):
    """Visualize crowd-counting results on ``img``.

    Two input layouts are supported:
      * ``len(dets) == 2``: (person_points, radius) — draw every person as a
        filled circle and paste a global "current head count" banner at the
        image's top-left corner.
      * ``len(dets) == 3``: (building_boxes, person_points, radius) — draw the
        person circles, then for each building box draw its polygon and a
        per-building count label clamped inside the image.

    Returns (annotated image, the unchanged ``dets``).
    """
    font = ImageFont.truetype(FONT_PATH, outfontsize, encoding='utf-8')
    if len(dets) == 2:
        # Global count banner text ("current head count: N").
        label = '当前人数:%d'%len(dets[0])
        detP = dets[0]
        # NOTE(review): despite the name, `line` is used as the circle
        # radius below — confirm against callers.
        line = dets[1]
        for p in detP:
            # One filled circle per detected person center.
            img = cv2.circle(img, (int(p[0]), int(p[1])), line, color, -1)
        label_arr = get_label_array(color, label, font, outfontsize)
        lh, lw = label_arr.shape[0:2]
        # Banner pasted at the top-left corner.
        img[0:lh, 0:lw, :] = label_arr
    elif len(dets) == 3:
        detP = dets[1]
        line = dets[2]
        for p in detP:
            img = cv2.circle(img, (int(p[0]), int(p[1])), line, color, -1)

        # Per-building boxes; b[4] carries the person count for that box.
        detM = dets[0]
        h, w = img.shape[:2]
        for b in detM:
            # "pedestrians under this building: N"
            label = '该建筑下行人及数量:%d'%(int(b[4]))
            label_arr = get_label_array(color, label, font, outfontsize)
            lh, lw = label_arr.shape[0:2]
            # Top-left corner of the box; label goes directly above it.
            x0, y1 = int(b[0]), int(b[1])
            x1, y0 = x0 + lw, y1 - lh
            # Clamp the label rectangle inside the image (same scheme as
            # plot_one_box_auto): shift down / up / right / left as needed.
            if y0 < 0:
                y0 = 0
                y1 = y0 + lh
            if y1 > h:
                y1 = h
                y0 = y1 - lh
            if x0 < 0:
                x0 = 0
                x1 = x0 + lw
            if x1 > w:
                x1 = w
                x0 = x1 - lw

            # Building outline in fixed orange (0,128,255), then the label.
            cv2.polylines(img, [np.asarray(xy2xyxy(b), np.int32)], True, (0, 128, 255), 2)
            img[y0:y1, x0:x1, :] = label_arr

    return img, dets
|
||||
|
||||
def draw_name_points(img,box,color):
    """Draw a flame detection: its contour, a class label on the left, an
    area label on the right, and an enclosing rectangle.

    ``box`` is expected to be a sequence whose first four entries are
    rectangle corner points and whose last entry is the contour point array
    — TODO confirm exact layout against callers.

    Returns (annotated image, the rectangle actually drawn, as int32).
    """
    font = ImageFont.truetype(FONT_PATH, 6, encoding='utf-8')
    # Last element is the contour used for the area computation and outline.
    points = box[-1]
    arrea = cv2.contourArea(points)
    # Class label text ("flame") and its area in scientific notation.
    label = '火焰'
    arealabel = '面积:%s' % f"{arrea:.1e}"
    label_array_area = get_label_array(color, arealabel, font, 10)
    label_array = get_label_array(color, label, font, 10)
    lh_area, lw_area = label_array_area.shape[0:2]
    # Keep only the four rectangle corners from here on.
    box = box[:4]
    # Anchor the labels just above the rectangle's top edge
    # (clamped to the image top).
    x0, y1 = box[0][0], max(box[0][1] - lh_area - 3, 0)
    x1, y0 = box[1][0], box[1][1]
    # Class label placed from the left edge, area label from the right.
    x0_label, y0_label, x1_label, y1_label = get_label_left(x0, y1, label_array, img)
    x0_area, y0_area, x1_area, y1_area = get_label_right(x1, y0, label_array_area)
    img[y0_label:y1_label, x0_label:x1_label, :] = label_array
    img[y0_area:y1_area, x0_area:x1_area, :] = label_array_area
    # cv2.drawContours(img, points, -1, color, tl)
    # Open polyline tracing the flame contour.
    cv2.polylines(img, [points], False, color, 2)
    # If the area label is narrower than the rectangle, keep the original
    # left edge; otherwise widen the rectangle to start at the label.
    if lw_area < box[1][0] - box[0][0]:
        box = [(x0, y1), (x1, y1), (x1, box[2][1]), (x0, box[2][1])]
    else:
        box = [(x0_label, y1), (x1, y1), (x1, box[2][1]), (x0_label, box[2][1])]
    box = np.asarray(box, np.int32)
    cv2.polylines(img, [box], True, color, 2)
    return img, box
|
||||
|
||||
def draw_name_border(box, color, label_array, border):
    """If the center of ``box`` falls inside the ``border`` polygon, switch
    the draw color to alert red and recolor the non-white pixels of the
    label image to match; otherwise return both unchanged.

    Returns (possibly-updated color, possibly-recolored label_array).
    """
    corners = xywh2xyxy(box[:4])
    center_x = int((corners[0][0] + corners[2][0]) / 2)
    center_y = int((corners[0][1] + corners[2][1]) / 2)
    # measureDist=False -> +1 inside, 0 on the edge, -1 outside.
    position = cv2.pointPolygonTest(border, (int(center_x), int(center_y)),
                                    False)
    if position == 1:
        color = [0, 0, 255]
        # Pixels within 30 of pure white count as label background.
        low = np.array([255 - 30] * 3, dtype=np.uint8)
        high = np.array([255, 255, 255], dtype=np.uint8)
        background = cv2.inRange(label_array, low, high)
        # Repaint every non-background pixel with the alert color,
        # keeping the white background untouched.
        repaint = np.full_like(label_array, color, dtype=np.uint8)
        label_array = np.where(background[..., None], label_array, repaint)
    return color, label_array
|
||||
|
||||
def draw_transparent_red_polygon(img, points, alpha=0.5):
    """
    Blend a semi-transparent red fill over a polygon region of ``img``.

    Args:
        img (numpy.ndarray): BGR image to draw on.
        points (numpy.ndarray): polygon vertices; reshaped to (-1, 1, 2) for
            cv2.fillPoly, so it must already be an integer ndarray.
        alpha (float): overlay opacity in [0, 1]; smaller is more transparent.

    Returns:
        numpy.ndarray: the blended image, re-encoded as uint8.

    Raises:
        ValueError: if ``img`` is None.

    (Doc fix: the previous docstring described ``image_path``/``output_path``
    parameters that do not exist; this function operates purely in memory.)
    """
    if img is None:
        raise ValueError("无法读取图像")

    # RGBA scratch layer of the same spatial size as the source image.
    overlay = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)

    # cv2.fillPoly expects an (N, 1, 2) vertex array.
    #pts = np.array(points, np.int32)
    pts = points.reshape((-1, 1, 2))

    # Fill the polygon; the 4th channel carries the per-pixel opacity.
    # NOTE(review): the fill is (255, 0, 0, a); after the RGBA->BGR
    # conversion below this renders red in the BGR image.
    cv2.fillPoly(overlay, [pts], (255, 0, 0, int(alpha * 255)))

    # Color channels of the overlay, converted for blending with BGR img.
    overlay_bgr = cv2.cvtColor(overlay, cv2.COLOR_RGBA2BGR)

    # Per-pixel blend weight from the alpha channel (0 outside the polygon).
    mask = overlay[:, :, 3] / 255.0
    mask = np.stack([mask] * 3, axis=-1)  # broadcast to 3 channels

    # Alpha-composite the overlay onto the image.
    img = img * (1 - mask) + overlay_bgr * mask
    img = img.astype(np.uint8)

    return img
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from json import loads
|
||||
|
||||
from ruamel.yaml import safe_load
|
||||
|
||||
|
||||
def getConfigs(path, read_type='yml'):
    """Load a configuration file.

    Args:
        path: path of the config file, read as UTF-8 text.
        read_type: 'json' or 'yml' (default).

    Returns:
        The parsed configuration object.

    Raises:
        Exception: when ``read_type`` is neither 'json' nor 'yml'.
    """
    with open(path, 'r', encoding="utf-8") as cfg_file:
        if read_type == 'json':
            return loads(cfg_file.read())
        if read_type == 'yml':
            return safe_load(cfg_file)
        raise Exception('路径: %s未获取配置信息' % path)
|
||||
|
||||
|
||||
def readFile(file, ty="rb"):
    """Return the whole contents of ``file``, opened with mode ``ty``
    (binary read by default)."""
    with open(file, ty) as handle:
        return handle.read()
|
||||
|
||||
|
|
@ -0,0 +1,817 @@
|
|||
from loguru import logger
|
||||
import json, cv2, time, os, torch, glob
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
from scipy import interpolate
|
||||
|
||||
from DrGraph.util import yoloHelper, torchHelper
|
||||
from DrGraph.util.drHelper import *
|
||||
|
||||
from DrGraph.util.segutils.trtUtils import segtrtEval,yolov5Trtforward,OcrTrtForward
|
||||
|
||||
def getDetectionsFromPreds(pred,img,im0,conf_thres=0.2,iou_thres=0.45,ovlap_thres=0.6,padInfos=None):
    '''
    Post-process raw YOLO predictions: NMS, coordinate restoration to the
    original image, and conversion to plain Python lists.

    Args:
        pred (torch.Tensor): raw detector output (boxes, scores, classes).
        img (torch.Tensor): the tensor fed to the detector, used as the
            reference for coordinate rescaling.
        im0 (numpy.ndarray): the original image the boxes are mapped back to.
        conf_thres (float): confidence threshold for the first NMS pass.
        iou_thres (float): IoU threshold for the first NMS pass.
        ovlap_thres (float): IoU threshold for an optional second,
            cross-overlap suppression pass; 0/None skips it.
        padInfos (list or None): padding info from the resize step, used to
            restore box positions exactly; None falls back to scale_coords.

    Returns:
        list: [im0, im0, det_xywh, 0] where det_xywh is a list of
        [x0, y0, x1, y1, score, cls] rows (im0 appears twice and the
        trailing 0 exist only for interface compatibility).
    '''
    with TimeDebugger('预测结果后处理') as td:
        # First NMS pass: drop low-confidence and overlapping boxes.
        pred = yoloHelper.non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)
        # Optional second pass suppressing cross-category overlaps.
        if ovlap_thres:
            pred = yoloHelper.overlap_box_suppression(pred, ovlap_thres)
        td.addStep("NMS")
        i=0;det=pred[0]  # one image per call
        det_xywh=[]

        # Restore coordinates and convert format when anything was detected.
        if len(det)>0:
            # Rescale boxes back to the original image size.
            H,W = im0.shape[0:2]
            det[:, :4] = imgHelper.scale_back( det[:, :4],padInfos).round() \
                if padInfos \
                else imgHelper.scale_coords(img.shape[2:], det[:, :4],im0.shape).round()

            # Convert to CPU numpy values and clamp into the image bounds.
            for *xyxy, conf, cls in reversed(det):
                cls_c = cls.cpu().numpy()
                conf_c = conf.cpu().numpy()
                tt=[ int(x.cpu()) for x in xyxy]
                x0,y0,x1,y1 = tt[0:4]
                x0 = max(0,x0);y0 = max(0,y0);
                x1 = min(W-1,x1);y1 = min(H-1,y1)
                #line = [float(cls_c), *tt, float(conf_c)] # label format ,
                line = [ x0,y0,x1,y1, float(conf_c),float(cls_c)] # label format, changed 2023-08-03
                #print('###line305:',line)
                det_xywh.append(line)

        td.addStep('ScaleBack')
    # The trailing 0 carries no meaning; kept only so the result stays a
    # 4-element list like the legacy interface.
    return [im0,im0,det_xywh,0]
|
||||
|
||||
def score_filter_byClass(pdetections, score_para_2nd):
    """Filter detections by per-class confidence thresholds.

    Args:
        pdetections: iterable of [x1, y1, x2, y2, score, cls] rows.
        score_para_2nd: mapping from class id to threshold; both int and
            string keys are honored (int keys take precedence), and classes
            absent from the mapping fall back to a 0.7 threshold.

    Returns:
        list: the rows whose score strictly exceeds their class threshold.
    """
    kept = []
    for row in pdetections:
        score, cls = row[4], row[5]
        cls_key = int(cls)
        # Resolve the threshold: int key first, then string key, then 0.7.
        if cls_key in score_para_2nd:
            threshold = score_para_2nd[cls_key]
        elif str(cls_key) in score_para_2nd:
            threshold = score_para_2nd[str(cls_key)]
        else:
            threshold = 0.7
        if score > threshold:
            kept.append(row)
    return kept
|
||||
|
||||
def AI_process(im0s, model, segmodel, names, label_arraylist, rainbows,
               objectPar={ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False,'score_byClass':{x:0.1 for x in range(30)} },
               font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3},
               segPar={'modelSize':(640,360),'mean':(0.485, 0.456, 0.406),'std' :(0.229, 0.224, 0.225),'numpy':False, 'RGB_convert_first':True},
               mode='others', postPar=None):
    # logger.info("AI_process(\n\rim0s={}, \n\rmodel={},\n\rsegmodel={},\n\rnames={},\n\rrainbows={},\n\robjectPar={},\n\rfont={},\n\rsegPar={},\n\rmode={},\n\rpostPar={})", \
    #             im0s, model, segmodel, names, rainbows, \
    #             objectPar, font, segPar, mode, postPar)
    """
    Run detection (and optional segmentation) on the input image and return
    the processed results.

    NOTE(review): the default arguments are mutable dicts shared across
    calls, and the default ``segPar`` has no 'mixFunction' key, so calling
    with defaults would raise KeyError at the mix step — confirm callers
    always pass explicit dicts.

    Args:
        im0s (list): original images; only im0s[0] is processed.
        model: detection model (TensorRT engine or torch model).
        segmodel: segmentation model, or None when unused.
        names (list): class names (unused here; kept for interface parity).
        label_arraylist: pre-rendered label images (unused here).
        rainbows: color table (unused here).
        objectPar (dict): detection parameters — half, device, conf_thres,
            iou_thres, allowedList, segRegionCnt, trtFlag_det, trtFlag_seg,
            optional ovlap_thres_crossCategory and score_byClass.
        font (dict): font/drawing parameters.
        segPar (dict): segmentation parameters; segPar['mixFunction'] may
            provide a post-processing callable and its 'pars' dict.
        mode (str): processing-mode tag (not used in this body).
        postPar: extra post-processing parameters (currently unused).

    Returns:
        tuple: (p_result, time_info) where p_result is
        [original image, processed image, boxes, 0(, seg_pred)] — boxes are
        [x0, y0, x1, y1, conf, cls] rows — and time_info is the per-stage
        timing report string.
    """

    # Pull the key parameters out of objectPar.
    half,device,conf_thres,iou_thres = objectPar['half'],objectPar['device'],objectPar['conf_thres'],objectPar['iou_thres']

    trtFlag_det,trtFlag_seg,segRegionCnt = objectPar['trtFlag_det'],objectPar['trtFlag_seg'],objectPar['segRegionCnt']
    if 'ovlap_thres_crossCategory' in objectPar.keys(): ovlap_thres = objectPar['ovlap_thres_crossCategory']
    else: ovlap_thres = None

    if 'score_byClass' in objectPar.keys(): score_byClass = objectPar['score_byClass']
    else: score_byClass = None

    with TimeDebugger('AI_process') as td:  # enabled logAtExit - emits the timing report when the block exits
        # Preprocess: TensorRT path pads to a fixed 640x640, torch path
        # uses letterbox resizing.
        if trtFlag_det:
            img, padInfos = imgHelper.img_pad(im0s[0], size=(640,640,3))
            img = [img]
        else:
            #print('####line72:',im0s[0][10:12,10:12,2])
            img = [imgHelper.letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]
            padInfos=None
        img_height, img_width = img[0].shape[0:2]  # preprocessed height/width
        #print('####line74:',img[0][10:12,10:12,2])
        # Stack and convert to model input layout (BGR->RGB, HWC->CHW).
        img = np.stack(img, 0)
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)
        td.addStep("img_pad" if trtFlag_det else "letterbox")

        # To torch tensor, onto the device, normalized into [0, 1].
        img = torch.from_numpy(img)
        td.addStep(f"from_numpy({img_height} x {img_width})")
        img = img.to(device)
        td.addStep(f"to GPU({img_height} x {img_width})" )

        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0
        # td.addStep("seg")

        # Run segmentation first when a segmentation model is provided.
        if segmodel:
            seg_pred,segstr = segmodel.eval(im0s[0] )
            segFlag=True
        else:
            seg_pred = None;segFlag=False;segstr='Not implemented'
        td.addStep("infer")
        # Detection inference (TensorRT or torch).
        if trtFlag_det:
            pred = yolov5Trtforward(model,img)
        else:
            #print('####line96:',img[0,0,10:12,10:12])
            pred = model(img,augment=False)[0]
        td.addStep('yolov5Trtforward' if trtFlag_det else 'model')

        # NMS + coordinate restoration.
        p_result = getDetectionsFromPreds(pred,img,im0s[0],conf_thres=conf_thres,iou_thres=iou_thres,ovlap_thres=ovlap_thres,padInfos=padInfos)
        # Per-class confidence filtering.
        if score_byClass:
            p_result[2] = score_filter_byClass(p_result[2],score_byClass)
        td.addStep('后处理')
        #print('-'*10,p_result[2])
        #if mode=='highWay3.0':
        #if segmodel:
        # Optional mix function: refine detection boxes with the seg result.
        if segPar and segPar['mixFunction']['function']:
            mixFunction = segPar['mixFunction']['function'];
            H,W = im0s[0].shape[0:2]
            parMix = segPar['mixFunction']['pars'];#print('###line117:',parMix,p_result[2])
            parMix['imgSize'] = (W,H)
            #print(' -----------line149: ',p_result[2] ,'\n', seg_pred, parMix ,' sumpSeg:',np.sum(seg_pred))
            logger.warning('启用混合后处理函数')
            p_result[2] , timeMixPost = mixFunction(p_result[2], seg_pred, pars=parMix )
            #print(' -----------line112: ',p_result[2] )
            p_result.append(seg_pred)

        else:
            timeMixPost=':0 ms'
        time_info = td.getReportInfo()
    return p_result,time_info
|
||||
|
||||
def AI_process_N(im0s, modelList, postProcess):
    """Run every model in ``modelList`` on the first image and merge their
    predictions with the configured post-processing function.

    Args:
        im0s: list of source images; only im0s[0] is evaluated.
        modelList: objects exposing ``eval(image) -> (prediction, time_str)``.
        postProcess: dict with 'function' (the merge callable) and 'pars'
            (its parameter dict; 'imgSize' is filled in here as (W, H)).

    Returns:
        (merged_result, time_info): the merge function's first output and
        the concatenated per-model + merge timing string.
    """
    # Each model yields a (prediction, timing-string) pair.
    outcomes = [m.eval(im0s[0]) for m in modelList]

    timing = ''.join(item[1] for item in outcomes)

    # Hand all raw predictions to the merge function, together with the
    # original image size.
    merge = postProcess['function']
    predictions = [outcome[0] for outcome in outcomes]
    height, width = im0s[0].shape[0:2]
    postProcess['pars']['imgSize'] = (width, height)

    merged = merge(predictions, postProcess['pars'])

    return merged[0], timing + merged[1]
|
||||
|
||||
def getMaxScoreWords(detRets0):
    """Return the index of the detection whose confidence (column 4) is the
    largest value above -1.  Ties keep the earliest index; 0 is returned
    when nothing qualifies (e.g. an empty list)."""
    best_idx = 0
    best_score = -1
    for idx, det in enumerate(detRets0):
        score = det[4]
        if score > best_score:
            best_score = score
            best_idx = idx
    return best_idx
|
||||
|
||||
def AI_process_C(im0s,modelList,postProcess):
    # Why this variant exists:
    #  The older pipeline was
    #    image -> model1 -> result1; image -> model2 -> result2;
    #    [result1, result2] -> post-processing
    #  This function instead chains the models:
    #    image -> model1 -> result1; [image, result1] -> model2 -> result2;
    #    [result1, result2] -> post-processing
    #  i.e. model2's input is derived from model1's output — e.g. model2 is
    #  an OCR model fed with the ship-name crops that model1 detected,
    #  whereas previously model2 was a seg model that saw the raw image.
    # Inputs:
    #  im0s        -- list of original images (only im0s[0] is used)
    #  modelList   -- the models; [0] detector, [1] crop-level model
    #  postProcess -- dict with the post-processing function and its pars
    # Outputs:
    #  ret[0] -- detection results; ret[1] -- timing string
    # NOTE(review): if postProcess['name'] is neither 'crackMeasurement'
    # nor 'channel2', `rets`/`outInfos` are never assigned and the final
    # return raises UnboundLocalError — confirm callers only use those two
    # names.  `maxId` is likewise shared between the two channel2 guards.

    t0=time.time()
    # Stage 1: full-image detection.
    detRets0 = modelList[0].eval(im0s[0])

    #detRets0=[[12, 46, 1127, 1544, 0.2340087890625, 2.0], [1884, 1248, 2992, 1485, 0.64208984375, 1.0]]
    detRets0 = detRets0[0]
    parsIn=postProcess['pars']

    # Split detections into the classes of interest vs. everything else.
    _detRets0_obj = list(filter(lambda x: x[5] in parsIn['objs'], detRets0 ))
    _detRets0_others = list(filter(lambda x: x[5] not in parsIn['objs'], detRets0 ))
    _detRets0 = []
    if postProcess['name']=='channel2':
        # channel2: keep only the single highest-score object for OCR.
        if len(_detRets0_obj)>0:
            maxId=getMaxScoreWords(_detRets0_obj)
            _detRets0 = _detRets0_obj[maxId:maxId+1]
    else: _detRets0 = detRets0


    t1=time.time()
    # Stage 2: run model 2 on each crop cut from stage-1 boxes.
    imagePatches = [ im0s[0][int(x[1]):int(x[3] ) ,int(x[0]):int(x[2])] for x in _detRets0 ]
    detRets1 = [modelList[1].eval(patch) for patch in imagePatches]
    print('###line240:',detRets1)
    if postProcess['name']=='crackMeasurement':
        # Crack masks come back normalized; scale to 0..255 for measuring.
        detRets1 = [x[0]*255 for x in detRets1]
        t2=time.time()
        mixFunction =postProcess['function']
        crackInfos = [mixFunction(patchMask,par=parsIn) for patchMask in detRets1]

        # Append each crop's crack measurements to its detection row.
        rets = [ _detRets0[i]+ crackInfos[i] for i in range(len(imagePatches)) ]
        t3=time.time()
        outInfos='total:%.1f (det:%.1f %d次segs:%.1f mixProcess:%.1f) '%( (t3-t0)*1000, (t1-t0)*1000, len(detRets1),(t2-t1)*1000, (t3-t2)*1000 )
    elif postProcess['name']=='channel2':
        H,W = im0s[0].shape[0:2];parsIn['imgSize'] = (W,H)
        mixFunction =postProcess['function']
        _detRets0_others = mixFunction([_detRets0_others], parsIn)
        ocrInfo='no ocr'
        if len(_detRets0_obj)>0:
            res_real = detRets1[0][0]
            # Keep only CJK characters and ASCII digits from the OCR text.
            res_real="".join( list(filter(lambda x:(ord(x) >19968 and ord(x)<63865 ) or (ord(x) >47 and ord(x)<58 ),res_real)))

            #detRets1[0][0]="".join( list(filter(lambda x:(ord(x) >19968 and ord(x)<63865 ) or (ord(x) >47 and ord(x)<58 ),detRets1[0][0])))
            _detRets0_obj[maxId].append(res_real )
            _detRets0_obj = [_detRets0_obj[maxId]]  # emit only the ship-name box that has OCR text
            ocrInfo=detRets1[0][1]
        print( ' _detRets0_obj:{} _detRets0_others:{} '.format( _detRets0_obj, _detRets0_others ) )
        rets=_detRets0_obj+_detRets0_others
        t3=time.time()
        outInfos='total:%.1f ,where det:%.1f, ocr:%s'%( (t3-t0)*1000, (t1-t0)*1000, ocrInfo)

    #print('###line233:',detRets1,detRets0 )

    return rets,outInfos
|
||||
|
||||
def post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,iframe,ObjectPar={ 'object_config':[0,1,2,3,4], 'slopeIndex':[5,6,7] ,'segmodel':True,'segRegionCnt':1 },font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3},padInfos=None ,ovlap_thres=None):
    """Full post-processing pipeline for one frame: NMS -> coordinate
    restoration -> segmentation-based filtering -> drawing.

    ``datas`` is the dataset tuple (path, model-input img, original images,
    vid_cap, raw pred, seg_pred, ...); only datas[0:6] are used.  Returns
    ([original image, annotated image, boxes, iframe], timing string) where
    boxes are [x0, y0, x1, y1, conf, cls] rows.
    NOTE(review): the default ObjectPar/font dicts are mutable and shared
    across calls; the bare except below can leave ``area_factors`` unbound
    and re-raise as NameError — confirm intended.
    """
    object_config,slopeIndex,segmodel,segRegionCnt=ObjectPar['object_config'],ObjectPar['slopeIndex'],ObjectPar['segmodel'],ObjectPar['segRegionCnt']
    # Input: dataset-generated data plus the model's raw pred and NMS params.
    # Pipeline: NMS -> coordinate restore -> draw.
    # Output: original image, annotated image, detection list.
    time0=time.time()
    path, img, im0s, vid_cap ,pred,seg_pred= datas[0:6];
    #segmodel=True
    pred = yoloHelper.non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)
    if ovlap_thres:
        pred = yoloHelper.overlap_box_suppression(pred, ovlap_thres)
    time1=time.time()
    i=0;det=pred[0]  # one image per call
    time1_1 = time.time()
    #p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
    p, s, im0 = path[i], '%g: ' % i, im0s[i]
    time1_2 = time.time()
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
    time1_3 = time.time()
    det_xywh=[];
    #im0_brg=cv2.cvtColor(im0,cv2.COLOR_RGB2BGR);
    if segmodel:
        # Two-element seg_pred -> illegal-building overlay; otherwise draw
        # the water region and keep its mask for box filtering below.
        if len(seg_pred)==2:
            im0,water = illBuildings(seg_pred,im0)
        else:
            river={ 'color':font['waterLineColor'],'line_width':font['waterLineWidth'],'segRegionCnt':segRegionCnt,'segLineShow':font['segLineShow'] }
            im0,water = drawWater(seg_pred,im0,river)
    time2=time.time()
    #plt.imshow(im0);plt.show()
    if len(det)>0:
        # Rescale boxes from img_size to im0 size
        if not padInfos:
            det[:, :4] = imgHelper.scale_coords(img.shape[2:], det[:, :4],im0.shape).round()
        else:
            #print('####line131:',det[:, :])
            det[:, :4] = imgHelper.scale_back( det[:, :4],padInfos).round()
            #print('####line133:',det[:, :])
        # Use the seg mask to keep only valid boxes (water/slope rule).
        if segmodel:
            cls_indexs = det[:, 5].clone().cpu().numpy().astype(np.int32)
            # Flag targets that belong to the river-bank (slope) classes.
            slope_flag = np.array([x in slopeIndex for x in cls_indexs ] )

            det_c = det.clone(); det_c=det_c.cpu().numpy()
            try:
                # Fraction of each box covered by the water mask.
                area_factors = np.array([np.sum(water[int(x[1]):int(x[3]), int(x[0]):int(x[2])] )*1.0/(1.0*(x[2]-x[0])*(x[3]-x[1])+0.00001) for x in det_c] )
            except:
                print('*****************************line143: error:',det_c)
            water_flag = np.array(area_factors>0.1)
            # Keep on-water targets whose water overlap exceeds 0.1, and
            # keep slope-class targets unconditionally.
            det = det[water_flag|slope_flag]
        # Draw the surviving detection boxes.

        for *xyxy, conf, cls in reversed(det):
            xywh = (mathHelper.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
            cls_c = cls.cpu().numpy()


            conf_c = conf.cpu().numpy()
            tt=[ int(x.cpu()) for x in xyxy]
            #line = [float(cls_c), *tt, float(conf_c)] # label format
            line = [*tt, float(conf_c), float(cls_c)] # label format
            det_xywh.append(line)
            label = f'{names[int(cls)]} {conf:.2f}'
            #print('- '*20, ' line165:',xyxy,cls,conf )
            if int(cls_c) not in object_config:  # skip classes not configured for display
                continue
            #print('- '*20, ' line168:',xyxy,cls,conf )
            im0 = drawHelper.draw_painting_joint(xyxy,im0,label_arraylist[int(cls)],score=conf,color=rainbows[int(cls)%20],font=font)
    time3=time.time()
    strout='nms:%s drawWater:%s,copy:%s,toTensor:%s,detDraw:%s '% ( \
        timeHelper.deltaTime_MS(time0,time1),
        timeHelper.deltaTime_MS(time1,time2),
        timeHelper.deltaTime_MS(time1_1,time1_2),
        timeHelper.deltaTime_MS(time1_2,time1_3),
        timeHelper.deltaTime_MS(time2,time3) )
    return [im0s[0],im0,det_xywh,iframe],strout
|
||||
|
||||
def AI_process_forest(im0s,model,segmodel,names,label_arraylist,rainbows,half=True,device='cuda:0',conf_thres=0.25, iou_thres=0.45,
          allowedList=[0,1,2,3], font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3} ,trtFlag_det=False,SecNms=None):
    """Forest-scenario pipeline: preprocess, detect, optionally segment,
    then hand everything to ``post_process_`` for filtering and drawing.

    Fix: the default ``device`` used to be ``' cuda:0'`` with a leading
    space, which is not a valid torch device string when the default was
    relied on; it is now ``'cuda:0'``.

    Args:
        im0s: list of original images.
        model: detection model; segmodel: segmentation model or None.
        names/label_arraylist/rainbows/font: drawing assets forwarded on.
        half/device: precision and device for the torch path.
        conf_thres/iou_thres: NMS thresholds; SecNms: optional second-pass
            overlap threshold; allowedList: classes allowed to be drawn.
        trtFlag_det: use TensorRT padding + forward when True.

    Returns:
        ([im0s[0], annotated image, det_xywh, iframe], timing string) —
        det_xywh rows are [x0, y0, x1, y1, conf, cls].
    """
    # Letterbox / pad preprocessing.
    time0=time.time()
    if trtFlag_det:
        img, padInfos = imgHelper.img_pad(im0s[0], size=(640,640,3)) ;img = [img]
    else:
        img = [imgHelper.letterbox(x, 640, auto=True, stride=32)[0] for x in im0s];padInfos=None
    #img = [letterbox(x, 640, auto=True, stride=32)[0] for x in im0s]
    # Stack
    img = np.stack(img, 0)
    # Convert: BGR to RGB, HWC to bsx3xHxW.
    img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)
    img = np.ascontiguousarray(img)

    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32

    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    # Segmentation runs on the original image when a model is provided.
    if segmodel:
        seg_pred,segstr = segmodel.eval(im0s[0] )
        segFlag=True
    else:
        seg_pred = None;segFlag=False
    time1=time.time()
    pred = yolov5Trtforward(model,img) if trtFlag_det else model(img,augment=False)[0]

    time2=time.time()
    datas = [[''], img, im0s, None,pred,seg_pred,10]

    ObjectPar={ 'object_config':allowedList, 'slopeIndex':[] ,'segmodel':segFlag,'segRegionCnt':0 }
    p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,ObjectPar=ObjectPar,font=font,padInfos=padInfos,ovlap_thres=SecNms)
    #print('###line274:',p_result[2])
    #p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,object_config=allowedList,segmodel=segFlag,font=font,padInfos=padInfos)
    time_info = 'letterbox:%.1f, infer:%.1f, '%( (time1-time0)*1000,(time2-time1)*1000 )
    return p_result,time_info+timeOut
|
||||
|
||||
|
||||
def AI_det_track( im0s_in,modelPar,processPar,sort_tracker,segPar=None):
    """Detect (and optionally segment) one frame, then feed the boxes to the
    SORT tracker when one is supplied.

    Args:
        im0s_in: (image list, frame index) pair.
        modelPar: {'det_Model': ..., 'seg_Model': ...}.
        processPar: half/device/conf_thres/iou_thres/trtFlag_det/iou2nd and
            optional score_byClass.
        sort_tracker: SORT tracker instance, or None/falsy to skip tracking.
        segPar: segmentation mix-function config, or None.

    Returns:
        (p_result, timing string); p_result gains tracked_dets (index 4)
        and tracker objects (index 5) when tracking ran.
    """
    im0s,iframe=im0s_in[0],im0s_in[1]
    model = modelPar['det_Model']
    segmodel = modelPar['seg_Model']
    half,device,conf_thres, iou_thres,trtFlag_det = processPar['half'], processPar['device'], processPar['conf_thres'], processPar['iou_thres'],processPar['trtFlag_det']
    if 'score_byClass' in processPar.keys(): score_byClass = processPar['score_byClass']
    else: score_byClass = None

    iou2nd = processPar['iou2nd']
    time0=time.time()

    # Preprocess: TensorRT fixed-size pad vs. torch letterbox.
    if trtFlag_det:
        img, padInfos = imgHelper.img_pad(im0s[0], size=(640,640,3))
        img = [img]
    else:
        img = [imgHelper.letterbox(x, 640, auto=True, stride=32)[0] for x in im0s];padInfos=None
    img = np.stack(img, 0)
    # Convert
    img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
    img = np.ascontiguousarray(img)

    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32

    img /= 255.0  # 0 - 255 to 0.0 - 1.0

    seg_pred = None;segFlag=False
    time1=time.time()
    pred = yolov5Trtforward(model,img) if trtFlag_det else model(img,augment=False)[0]

    time2=time.time()

    #p_result,timeOut = getDetections(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,10,ObjectPar=ObjectPar,font=font,padInfos=padInfos)
    # NOTE(review): getDetectionsFromPreds as defined in this module returns
    # a single 4-element list, so this 2-way unpack would raise — confirm
    # which helper version this call is meant to pair with.
    p_result, timeOut = getDetectionsFromPreds(pred,img,im0s[0],conf_thres=conf_thres,iou_thres=iou_thres,ovlap_thres=iou2nd,padInfos=padInfos)
    if score_byClass:
        p_result[2] = score_filter_byClass(p_result[2],score_byClass)
    if segmodel:
        seg_pred,segstr = segmodel.eval(im0s[0] )
        segFlag=True
    else:
        seg_pred = None;segFlag=False;segstr='No segmodel'


    # Optional mix function: refine detection boxes with the seg result.
    if segPar and segPar['mixFunction']['function']:
        mixFunction = segPar['mixFunction']['function']

        H,W = im0s[0].shape[0:2]
        parMix = segPar['mixFunction']['pars'];#print('###line117:',parMix,p_result[2])
        parMix['imgSize'] = (W,H)


        p_result[2],timeInfos_post = mixFunction(p_result[2], seg_pred, pars=parMix )
        timeInfos_seg_post = 'segInfer:%s ,postMixProcess:%s'%( segstr, timeInfos_post )
    else:
        timeInfos_seg_post = ' '
    '''
    if segmodel:
        timeS1=time.time()
        #seg_pred,segstr = segtrtEval(segmodel,im0s[0],par=segPar) if segPar['trtFlag_seg'] else segmodel.eval(im0s[0] )
        seg_pred,segstr = segmodel.eval(im0s[0] )
        timeS2=time.time()
        mixFunction = segPar['mixFunction']['function']

        p_result[2],timeInfos_post = mixFunction(p_result[2], seg_pred, pars=segPar['mixFunction']['pars'] )

        timeInfos_seg_post = 'segInfer:%.1f ,postProcess:%s'%( (timeS2-timeS1)*1000, timeInfos_post )

    else:
        timeInfos_seg_post = ' '
    #print('######line341:',seg_pred.shape,np.max(seg_pred),np.min(seg_pred) , len(p_result[2]) )
    '''
    time_info = 'letterbox:%.1f, detinfer:%.1f, '%( (time1-time0)*1000,(time2-time1)*1000 )

    if sort_tracker:
        # The tracker-invocation frequency could be throttled here.
        #..................USE TRACK FUNCTION....................
        #pass an empty array to sort
        dets_to_sort = np.empty((0,7), dtype=np.float32)

        # NOTE: We send in detected object class too
        #for detclass,x1,y1,x2,y2,conf in p_result[2]:
        for x1,y1,x2,y2,conf, detclass in p_result[2]:
            #print('#######line342:',x1,y1,x2,y2,img.shape,[x1, y1, x2, y2, conf, detclass,iframe])
            dets_to_sort = np.vstack((dets_to_sort,
                        np.array([x1, y1, x2, y2, conf, detclass,iframe],dtype=np.float32) ))

        # Run SORT
        tracked_dets = deepcopy(sort_tracker.update(dets_to_sort) )
        tracks =sort_tracker.getTrackers()
        p_result.append(tracked_dets) ###index=4
        p_result.append(tracks) ###index=5

    return p_result,time_info+timeOut+timeInfos_seg_post
|
||||
def AI_det_track_batch(imgarray_list, iframe_list ,modelPar,processPar,sort_tracker,trackPar,segPar=None):
    '''
    Detect + track a batch of frames, interpolating boxes for frames that were skipped.

    Inputs:
        imgarray_list -- list of images (numpy arrays)
        iframe_list   -- list of frame ids, parallel to imgarray_list
        modelPar      -- model dict, modelPar={'det_Model':..., 'seg_Model':...}
        processPar    -- detection params dict: 'half', 'device', 'conf_thres', 'iou_thres', 'trtFlag_det'
        sort_tracker  -- initialized SORT tracker object (kept even for a single frame, for interface consistency)
        trackPar      -- tracking params dict; keys: det_cnt (run detection every det_cnt-th frame),
                         windowsize (odd smoothing-window length)
        segPar        -- segmentation params dict, or None when unused

    Outputs: (retResults, timeInfos)
        retResults[0] -- imgarray_list (unchanged)
        retResults[1] -- all results as one numpy array; each row is
                         [x1, y1, x2, y2, conf, detclass, iframe, trackId]
        retResults[2] -- same results as a nested list, one sub-list per frame;
                         retResults[2][j][k] is box k of frame j as
                         [x0, y0, x1, y1, conf, cls, iframe, trackId]  (format changed 2023.08.03)
    '''
    det_cnt,windowsize = trackPar['det_cnt'] ,trackPar['windowsize']
    trackers_dic={}
    # Detection is run only on every det_cnt-th frame; skipped frames are filled by interpolation below.
    index_list = list(range( 0, len(iframe_list) ,det_cnt ));
    # NOTE(review): index_list holds *indices* while iframe_list holds *frame ids*; this
    # comparison only behaves as intended when the two coincide — confirm with callers.
    if len(index_list)>1 and index_list[-1]!= iframe_list[-1]:
        index_list.append( len(iframe_list) - 1 )

    if len(imgarray_list)==1: # single image: no tracking needed
        retResults = []
        p_result,timeOut = AI_det_track( [ [imgarray_list[0]] ,iframe_list[0] ],modelPar,processPar,None,segPar )
        ## The following lines only exist to keep the output format identical to the multi-frame path.
        detArray = np.array(p_result[2])
        if len(p_result[2])==0:res=[]
        else:
            cnt = detArray.shape[0];trackIds=np.zeros((cnt,1));iframes = np.zeros((cnt,1)) + iframe_list[0]
            # Append the frame id and a dummy trackId column -> [x0,y0,x1,y1,conf,cls,iframe,trackId]
            detArray = np.hstack( (detArray[:,0:4], detArray[:,4:6] ,iframes, trackIds ) ) ##2023.08.03 input format changed
            res = [[ b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7] ] for b in detArray ]
        retResults=[imgarray_list,detArray,res ]
        return retResults,timeOut

    else:
        t0 = time.time()
        timeInfos_track=''
        # Detect and update the tracker on the sampled frames only.
        for iframe_index, index_frame in enumerate(index_list):
            p_result,timeOut = AI_det_track( [ [imgarray_list[index_frame]] ,iframe_list[index_frame] ],modelPar,processPar,sort_tracker,segPar )
            timeInfos_track='%s:%s'%(timeInfos_track,timeOut)
            # p_result[5] holds the live tracker chains; keep the latest snapshot of each id.
            for tracker in p_result[5]:
                trackers_dic[tracker.id]=deepcopy(tracker)
        t1 = time.time()

        track_det_result = np.empty((0,8))
        for trackId in trackers_dic.keys():
            tracker = trackers_dic[trackId]
            bbox_history = np.array(tracker.bbox_history)
            if len(bbox_history)<2: continue  # cannot interpolate from a single sample
            ### Convert (x0,y0,x1,y1) to (xc,yc,w,h) so centers and sizes interpolate independently.
            xcs_ycs = (bbox_history[:,0:2] + bbox_history[:,2:4] )/2
            whs = bbox_history[:,2:4] - bbox_history[:,0:2]
            bbox_history[:,0:2] = xcs_ycs;bbox_history[:,2:4] = whs;

            arrays_box = bbox_history[:,0:7].transpose();frames=bbox_history[:,6]
            # frames[0] is the first frame this target appeared in; a track may span
            # several batches, so it can predate the current batch's first frame.
            ## To recover the full trajectory, interpolate over the target's own frame span.
            inter_frame_min=int(frames[0]);inter_frame_max=int(frames[-1])
            new_frames= np.linspace(inter_frame_min,inter_frame_max,inter_frame_max-inter_frame_min+1 )
            f_linear = interpolate.interp1d(frames,arrays_box); interpolation_x0s = (f_linear(new_frames)).transpose()
            # The smoothing window must be odd and no longer than the interpolated sequence.
            move_cnt_use =(len(interpolation_x0s)+1)//2*2-1 if len(interpolation_x0s)<windowsize else windowsize
            for im in range(4):  # smooth xc, yc, w, h only
                interpolation_x0s[:,im] = moving_average_wang(interpolation_x0s[:,im],move_cnt_use )

            cnt = inter_frame_max-inter_frame_min+1; trackIds = np.zeros((cnt,1)) + trackId
            interpolation_x0s = np.hstack( (interpolation_x0s, trackIds ) )
            track_det_result = np.vstack(( track_det_result, interpolation_x0s) )

        ## Convert [xc,yc,w,h] back to [x0,y0,x1,y1].
        x0s = track_det_result[:,0] - track_det_result[:,2]/2 ; x1s = track_det_result[:,0] + track_det_result[:,2]/2
        y0s = track_det_result[:,1] - track_det_result[:,3]/2 ; y1s = track_det_result[:,1] + track_det_result[:,3]/2
        track_det_result[:,0] = x0s; track_det_result[:,1] = y0s;
        track_det_result[:,2] = x1s; track_det_result[:,3] = y1s;

        # Regroup per frame: one list of [x0,y0,x1,y1,conf,cls,iframe,trackId] per input frame.
        detResults=[]
        for iiframe in iframe_list:
            boxes_oneFrame = track_det_result[ track_det_result[:,6]==iiframe ]
            res = [[ b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7] ] for b in boxes_oneFrame ]
            detResults.append( res )

        retResults=[imgarray_list,track_det_result,detResults ]
        t2 = time.time()
        timeInfos = 'detTrack:%.1f TrackPost:%.1f, %s'%( \
            timeHelper.deltaTime_MS(t1,t0), \
            timeHelper.deltaTime_MS(t2,t1), \
            timeInfos_track )
        return retResults,timeInfos
|
||||
def AI_det_track_N( im0s_in,modelList,postProcess,sort_tracker):
    """Detect objects on one frame and optionally push them through the SORT tracker.

    im0s_in is [images, iframe]. Returns (p_result, timing_str) where p_result[2]
    holds the raw detections and, when a tracker is given, p_result[4]/[5] hold
    the tracked detections and the tracker chains.
    """
    im0s, iframe = im0s_in[0], im0s_in[1]
    dets = AI_process_N(im0s, modelList, postProcess)
    p_result = [[], [], dets[0], []]
    if sort_tracker:
        # Build the SORT input: one row [x1, y1, x2, y2, conf, cls, iframe] per detection.
        rows = [
            np.array([x1, y1, x2, y2, conf, detclass, iframe], dtype=np.float32)
            for x1, y1, x2, y2, conf, detclass in p_result[2]
        ]
        if rows:
            dets_to_sort = np.stack(rows).astype(np.float32)
        else:
            # SORT still expects an (empty) update even when nothing was detected.
            dets_to_sort = np.empty((0, 7), dtype=np.float32)

        # Run SORT; copy the result so later tracker updates cannot mutate it.
        tracked_dets = deepcopy(sort_tracker.update(dets_to_sort))
        tracks = sort_tracker.getTrackers()
        p_result.append(tracked_dets)  ###index=4
        p_result.append(tracks)  ###index=5

    return p_result, dets[1]
|
||||
def get_tracker_cls(boxes,scId=4,clsId=5):
    """Return the dominant class id of one tracking chain.

    A chain should carry a single class, but mis-detections can mix classes in.
    The chain class is therefore the class whose boxes have the largest summed
    confidence.

    Args:
        boxes: numpy array of the chain's box history, rows like
               [xc, yc, width, height, score, class, iframe].
        scId:  column index of the confidence score (default 4).
        clsId: column index of the class id (default 5).

    Returns:
        int: dominant class id.
    """
    # np.unique gives a deterministic (sorted) class order, unlike list(set(...)),
    # so ties now resolve to the smallest class id instead of arbitrarily.
    ids = np.unique(boxes[:, clsId])
    scores = np.array([boxes[boxes[:, clsId] == c, scId].sum() for c in ids])
    return int(ids[int(np.argmax(scores))])
|
||||
|
||||
def AI_det_track_batch_N(imgarray_list, iframe_list ,modelList,postProcess,sort_tracker,trackPar):
    '''
    Detect + track a batch of frames (multi-model variant of AI_det_track_batch).

    Inputs:
        imgarray_list -- list of images (numpy arrays)
        iframe_list   -- list of frame ids, parallel to imgarray_list
        modelList     -- list of models consumed by AI_det_track_N
        postProcess   -- post-processing parameters consumed by AI_det_track_N
        sort_tracker  -- initialized SORT tracker object (kept even for a single frame, for interface consistency)
        trackPar      -- tracking params dict; keys: det_cnt (run detection every det_cnt-th frame),
                         windowsize (odd smoothing-window length)

    Outputs: (retResults, timeInfos)
        retResults[0] -- imgarray_list (unchanged)
        retResults[1] -- all results as one numpy array; each row is
                         [x1, y1, x2, y2, conf, detclass, iframe, trackId]
        retResults[2] -- same results as a nested list, one sub-list per frame;
                         retResults[2][j][k] is box k of frame j as
                         [x0, y0, x1, y1, conf, cls, iframe, trackId]  (format changed 2023.08.03)
    '''
    det_cnt,windowsize = trackPar['det_cnt'] ,trackPar['windowsize']
    trackers_dic={}
    # Detection is run only on every det_cnt-th frame; skipped frames are filled by interpolation below.
    index_list = list(range( 0, len(iframe_list) ,det_cnt ));
    # NOTE(review): index_list holds *indices* while iframe_list holds *frame ids*; this
    # comparison only behaves as intended when the two coincide — confirm with callers.
    if len(index_list)>1 and index_list[-1]!= iframe_list[-1]:
        index_list.append( len(iframe_list) - 1 )

    if len(imgarray_list)==1: # single image: no tracking needed
        retResults = []
        p_result,timeOut = AI_det_track_N( [ [imgarray_list[0]] ,iframe_list[0] ],modelList,postProcess,None )
        ## The following lines only exist to keep the output format identical to the multi-frame path.
        detArray = np.array(p_result[2])
        if len(p_result[2])==0:res=[]
        else:
            cnt = detArray.shape[0];trackIds=np.zeros((cnt,1));iframes = np.zeros((cnt,1)) + iframe_list[0]
            # Append the frame id and a dummy trackId column -> [x0,y0,x1,y1,conf,cls,iframe,trackId]
            detArray = np.hstack( (detArray[:,0:4], detArray[:,4:6] ,iframes, trackIds ) ) ##2023.08.03 input format changed
            res = [[ b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7] ] for b in detArray ]
        retResults=[imgarray_list,detArray,res ]
        return retResults,timeOut

    else:
        t0 = time.time()
        timeInfos_track=''
        # Detect and update the tracker on the sampled frames only.
        for iframe_index, index_frame in enumerate(index_list):
            p_result,timeOut = AI_det_track_N( [ [imgarray_list[index_frame]] ,iframe_list[index_frame] ],modelList,postProcess,sort_tracker )
            timeInfos_track='%s:%s'%(timeInfos_track,timeOut)
            # p_result[5] holds the live tracker chains; keep the latest snapshot of each id.
            for tracker in p_result[5]:
                trackers_dic[tracker.id]=deepcopy(tracker)
        t1 = time.time()

        track_det_result = np.empty((0,8))
        for trackId in trackers_dic.keys():
            tracker = trackers_dic[trackId]
            # .copy() so the class-id rewrite below does not mutate the tracker's own history.
            bbox_history = np.array(tracker.bbox_history).copy()
            if len(bbox_history)<2: continue  # cannot interpolate from a single sample
            ### Convert (x0,y0,x1,y1) to (xc,yc,w,h) so centers and sizes interpolate independently.
            xcs_ycs = (bbox_history[:,0:2] + bbox_history[:,2:4] )/2
            whs = bbox_history[:,2:4] - bbox_history[:,0:2]
            bbox_history[:,0:2] = xcs_ycs;bbox_history[:,2:4] = whs;

            # Added 2023.11.17: force every box on a chain to the chain's dominant class.
            chainClsId = get_tracker_cls(bbox_history,scId=4,clsId=5)
            bbox_history[:,5] = chainClsId

            arrays_box = bbox_history[:,0:7].transpose();frames=bbox_history[:,6]
            # frames[0] is the first frame this target appeared in; a track may span
            # several batches, so it can predate the current batch's first frame.
            ## To recover the full trajectory, interpolate over the target's own frame span.
            inter_frame_min=int(frames[0]);inter_frame_max=int(frames[-1])
            new_frames= np.linspace(inter_frame_min,inter_frame_max,inter_frame_max-inter_frame_min+1 )
            f_linear = interpolate.interp1d(frames,arrays_box); interpolation_x0s = (f_linear(new_frames)).transpose()
            # The smoothing window must be odd and no longer than the interpolated sequence.
            move_cnt_use =(len(interpolation_x0s)+1)//2*2-1 if len(interpolation_x0s)<windowsize else windowsize
            for im in range(4):  # smooth xc, yc, w, h only
                interpolation_x0s[:,im] = moving_average_wang(interpolation_x0s[:,im],move_cnt_use )

            cnt = inter_frame_max-inter_frame_min+1; trackIds = np.zeros((cnt,1)) + trackId
            interpolation_x0s = np.hstack( (interpolation_x0s, trackIds ) )
            track_det_result = np.vstack(( track_det_result, interpolation_x0s) )

        ## Convert [xc,yc,w,h] back to [x0,y0,x1,y1].
        x0s = track_det_result[:,0] - track_det_result[:,2]/2 ; x1s = track_det_result[:,0] + track_det_result[:,2]/2
        y0s = track_det_result[:,1] - track_det_result[:,3]/2 ; y1s = track_det_result[:,1] + track_det_result[:,3]/2
        track_det_result[:,0] = x0s; track_det_result[:,1] = y0s;
        track_det_result[:,2] = x1s; track_det_result[:,3] = y1s;

        # Regroup per frame: one list of [x0,y0,x1,y1,conf,cls,iframe,trackId] per input frame.
        detResults=[]
        for iiframe in iframe_list:
            boxes_oneFrame = track_det_result[ track_det_result[:,6]==iiframe ]
            res = [[ b[0],b[1],b[2],b[3],b[4],b[5],b[6],b[7] ] for b in boxes_oneFrame ]
            detResults.append( res )

        retResults=[imgarray_list,track_det_result,detResults ]
        t2 = time.time()
        timeInfos = 'detTrack:%.1f TrackPost:%.1f, %s'%( \
            timeHelper.deltaTime_MS(t1,t0), \
            timeHelper.deltaTime_MS(t2,t1), \
            timeInfos_track )
        return retResults,timeInfos
|
||||
|
||||
def ocr_process(pars):
    """Run CRNN OCR (TensorRT) on one grayscale image patch.

    Args:
        pars: sequence whose first six items are
            img_patch           -- grayscale patch (H, W) numpy array
            engine, context     -- TensorRT engine and execution context
            converter           -- label converter exposing decode_greedy()
            AlignCollate_normal -- preprocessing collate producing a batch tensor
            device              -- torch device for post-processing tensors

    Returns:
        (preds_str, info_str): decoded text and a timing/debug string.
    """
    img_patch,engine,context,converter,AlignCollate_normal,device=pars[0:6]
    time1 = time.time()
    img_tensor = AlignCollate_normal([ Image.fromarray(img_patch,'L') ])
    # Bug fix: the input was hard-coded to 'cuda:0', silently ignoring the
    # caller-supplied device (which IS used below for preds_prob). Honor it.
    img_input = img_tensor.to(device)
    time2 = time.time()

    preds,trtstr=OcrTrtForward(engine,[img_input],context)
    time3 = time.time()

    batch_size = preds.size(0)
    # CTC decoding needs the (constant) sequence length per batch item.
    preds_size = torch.IntTensor([preds.size(1)] * batch_size)

    ######## filter ignore_char, rebalance
    preds_prob = F.softmax(preds, dim=2)
    preds_prob = preds_prob.cpu().detach().numpy()
    pred_norm = preds_prob.sum(axis=2)
    preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1)
    preds_prob = torch.from_numpy(preds_prob).float().to(device)
    _, preds_index = preds_prob.max(2)
    preds_index = preds_index.view(-1)
    time4 = time.time()
    preds_str = converter.decode_greedy(preds_index.data.cpu().detach().numpy(), preds_size.data)
    time5 = time.time()

    # Bug fix: '%2.f' (field width 2, zero decimals) was a typo for '%.2f'.
    info_str= ('pre-process:%.2f TRTforward:%.2f (%s) postProcess:%.2f decoder:%.2f, Total:%.2f , pred:%s'%(\
        timeHelper.deltaTime_MS(time2,time1 ), \
        timeHelper.deltaTime_MS(time3,time2 ),trtstr, \
        timeHelper.deltaTime_MS(time4,time3 ), \
        timeHelper.deltaTime_MS(time5,time4 ), \
        timeHelper.deltaTime_MS(time5,time1 ), preds_str ) )
    return preds_str,info_str
|
||||
|
|
@ -0,0 +1,407 @@
|
|||
# YOLOv5 common modules
|
||||
|
||||
import math
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from torch.cuda import amp
|
||||
|
||||
from DrGraph.util.datasets import letterbox
|
||||
from DrGraph.util.general import non_max_suppression, make_divisible, increment_path, xyxy2xywh
|
||||
from DrGraph.util.plots import color_list, plot_one_box
|
||||
from DrGraph.util.torch_utils import time_synchronized
|
||||
|
||||
from DrGraph.util.drHelper import *
|
||||
|
||||
import warnings
|
||||
|
||||
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.

    Equivalent to SPP(k=(5, 9, 13)) but applies one k x k max-pool three times.
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            # Cascade the same pool: [x, m(x), m(m(x)), m(m(m(x)))].
            pooled = [x]
            for _ in range(3):
                pooled.append(self.m(pooled[-1]))
            return self.cv2(torch.cat(pooled, 1))
|
||||
|
||||
|
||||
def autopad(k, p=None):  # kernel, padding
    """Return the padding that keeps spatial size ('same') for kernel size k.

    If p is given it is returned untouched; otherwise it is k // 2, applied
    element-wise when k is a sequence.
    """
    if p is not None:
        return p
    return k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||
|
||||
|
||||
def DWConv(c1, c2, k=1, s=1, act=True):
    """Depthwise convolution: a Conv whose group count is gcd(c1, c2)."""
    groups = math.gcd(c1, c2)
    return Conv(c1, c2, k, s, g=groups, act=act)
|
||||
|
||||
|
||||
class Conv(nn.Module):
    """Standard convolution block: Conv2d -> BatchNorm2d -> activation."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # act=True -> SiLU; an nn.Module -> used as-is; anything else -> identity.
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Normal path: conv, batch-norm, activation."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Path used after BN has been folded into the conv weights."""
        return self.act(self.conv(x))
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
    """Transformer layer (https://arxiv.org/abs/2010.11929).

    LayerNorm layers are removed for better performance; residual connections
    wrap both the attention and the two-linear feed-forward path.
    """

    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        # MultiheadAttention returns (output, weights); only the output is used.
        attn_out, _ = self.ma(self.q(x), self.k(x), self.v(x))
        x = attn_out + x                 # residual around attention
        x = self.fc2(self.fc1(x)) + x    # residual around the feed-forward pair
        return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, num_heads, num_layers):
        """Optional 1x1 Conv to match channels, then a stack of TransformerLayers."""
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
        self.c2 = c2

    def forward(self, x):
        # Flatten (b, c, H, W) into a sequence of H*W tokens, add the learned
        # position embedding, run the transformer stack, restore spatial layout.
        if self.conv is not None:
            x = self.conv(x)
        # NOTE(review): dims 2 and 3 of a Conv2d output are (H, W) but are named
        # w, h here; the reshape round-trip below is self-consistent, so only the
        # naming is misleading — confirm before reusing these names.
        b, _, w, h = x.shape
        p = x.flatten(2)
        p = p.unsqueeze(0)
        p = p.transpose(0, 3)
        p = p.squeeze(3)
        e = self.linear(p)
        x = p + e

        x = self.tr(x)
        # Invert the flatten/transpose sequence back to (b, c2, w, h).
        x = x.unsqueeze(3)
        x = x.transpose(0, 3)
        x = x.reshape(b, self.c2, w, h)
        return x
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 reduce -> 3x3 conv, with optional residual add."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        # The residual is only valid when input and output channels match.
        self.add = shortcut and c1 == c2

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
|
||||
|
||||
|
||||
class BottleneckCSP(nn.Module):
    """CSP Bottleneck (https://github.com/WongKinYiu/CrossStagePartialNetworks)."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSP, self).__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = nn.Conv2d(c1, hidden, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden, hidden, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden)  # applied to cat(cv3, cv2) output
        self.act = nn.LeakyReLU(0.1, inplace=True)
        self.m = nn.Sequential(*(Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        main = self.cv3(self.m(self.cv1(x)))  # bottleneck-processed branch
        bypass = self.cv2(x)                  # plain 1x1 branch
        fused = torch.cat((main, bypass), dim=1)
        return self.cv4(self.act(self.bn(fused)))
|
||||
|
||||
|
||||
class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3, self).__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        processed = self.m(self.cv1(x))  # bottleneck-stack branch
        bypass = self.cv2(x)             # plain 1x1 branch
        return self.cv3(torch.cat((processed, bypass), dim=1))
|
||||
|
||||
|
||||
class C3TR(C3):
    # C3 module with TransformerBlock()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Same as C3 but the Bottleneck stack is replaced by one TransformerBlock
        (4 heads, n layers) on the hidden channels."""
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = TransformerBlock(c_, c_, 4, n)
|
||||
|
||||
|
||||
class SPP(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP."""

    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * (len(k) + 1), c2, 1, 1)
        # One stride-1 max-pool per kernel size; padding keeps spatial size.
        self.m = nn.ModuleList(nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in k)

    def forward(self, x):
        x = self.cv1(x)
        pooled = [m(x) for m in self.m]
        return self.cv2(torch.cat([x, *pooled], 1))
|
||||
|
||||
|
||||
class Focus(nn.Module):
    """Focus wh information into channel space: x(b,c,w,h) -> y(b,4c,w/2,h/2), then conv."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):
        # Sample the four pixel phases; concat order must stay fixed so weights match.
        even_even = x[..., ::2, ::2]
        odd_even = x[..., 1::2, ::2]
        even_odd = x[..., ::2, 1::2]
        odd_odd = x[..., 1::2, 1::2]
        return self.conv(torch.cat([even_even, odd_even, even_odd, odd_odd], 1))
|
||||
|
||||
|
||||
class Contract(nn.Module):
    """Contract width-height into channels, e.g. x(1,64,80,80) -> x(1,256,40,40)."""

    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        n, c, h, w = x.size()  # H and W must be divisible by gain
        g = self.gain
        # Split each spatial dim into (blocks, gain) then fold the gain factors
        # into the channel dimension.
        x = x.view(n, c, h // g, g, w // g, g)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
        return x.view(n, c * g * g, h // g, w // g)
|
||||
|
||||
|
||||
class Expand(nn.Module):
    """Expand channels into width-height, e.g. x(1,64,80,80) -> x(1,16,160,160)."""

    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        n, c, h, w = x.size()  # C must be divisible by gain ** 2
        g = self.gain
        # Peel two gain factors off the channel dimension and weave them back
        # into the spatial dimensions (inverse of Contract).
        x = x.view(n, g, g, c // g ** 2, h, w)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()
        return x.view(n, c // g ** 2, h * g, w * g)
|
||||
|
||||
|
||||
class Concat(nn.Module):
    """Concatenate a list of tensors along a fixed dimension."""

    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        """x is a list/tuple of tensors with matching shapes off dim self.d."""
        return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class NMS(nn.Module):
    # Non-Maximum Suppression (NMS) module
    # Class-level defaults; may be overridden per instance before forward().
    conf = 0.25  # confidence threshold
    iou = 0.45  # IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self):
        super(NMS, self).__init__()

    def forward(self, x):
        # x[0] is the raw prediction tensor from the detection head; returns a
        # list of per-image detection tensors after NMS.
        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
|
||||
|
||||
|
||||
class autoShape(nn.Module):
    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self, model):
        """Wrap a detection model; the model is switched to eval mode."""
        super(autoShape, self).__init__()
        self.model = model.eval()

    def autoshape(self):
        # Idempotent: calling autoshape() again is a no-op on an autoShape instance.
        print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
        return self

    @torch.no_grad()
    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
        #   filename:   imgs = 'data/images/zidane.jpg'
        #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        t = [time_synchronized()]
        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch: already preprocessed, run directly
            with amp.autocast(enabled=p.device.type != 'cpu'):
                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process: normalize each input to a contiguous HWC 3-channel numpy array.
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            f = f'image{i}'  # filename
            if isinstance(im, str):  # filename or uri
                im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
            elif isinstance(im, Image.Image):  # PIL Image
                im, f = np.asarray(im), getattr(im, 'filename', f) or f
            files.append(Path(f).with_suffix('.jpg').name)
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
        # NOTE(review): self.stride and self.names are read here but never set in
        # this class — presumably attached from the wrapped model elsewhere; confirm.
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
        t.append(time_synchronized())

        with amp.autocast(enabled=p.device.type != 'cpu'):
            # Inference
            y = self.model(x, augment, profile)[0]  # forward
            t.append(time_synchronized())

            # Post-process: NMS, then map boxes back to each original image size.
            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
            for i in range(n):
                imgHelper.scale_coords(shape1, y[i][:, :4], shape0[i])

            t.append(time_synchronized())
            return Detections(imgs, y, files, t, self.names, x.shape)
|
||||
|
||||
|
||||
class Detections:
    # detections class for YOLOv5 inference results
    def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
        """Hold per-image predictions plus pixel/normalized xyxy/xywh views.

        imgs  -- list of images as numpy arrays
        pred  -- list of tensors; each row is (xyxy, conf, cls)
        files -- image filenames
        times -- 4 timestamps from autoShape.forward (pre, infer, nms, end)
        names -- class names
        shape -- inference BCHW shape
        """
        super(Detections, self).__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        # NOTE(review): times is indexed unconditionally, so the times=None default
        # raises TypeError here — callers must always pass the 4-element list.
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
        # Draw/print/show/save/render results for every image in the batch.
        colors = color_list()
        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
            # NOTE(review): 'str' shadows the builtin; harmless here since the
            # builtin is not used inside this loop, but rename on next touch.
            str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
            if pred is not None:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render:
                    for *box, conf, cls in pred:  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
            if pprint:
                print(str.rstrip(', '))
            if show:
                img.show(self.files[i])  # show
            if save:
                f = self.files[i]
                img.save(Path(save_dir) / f)  # save
                print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
            if render:
                self.imgs[i] = np.asarray(img)

    def print(self):
        self.display(pprint=True)  # print results
        print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='runs/hub/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp')  # increment save_dir
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        self.display(save=True, save_dir=save_dir)  # save results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def pandas(self):
        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        # NOTE(review): arguments are positional, so self.names lands in the 'files'
        # slot and self.s in 'times' (which __init__ then indexes) — this looks like
        # a bug; it should pass [self.files[i]] and forward names/shape by keyword.
        # Confirm before relying on tolist().
        x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
        for d in x:
            for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

    def __len__(self):
        return self.n
|
||||
|
||||
|
||||
class Classify(nn.Module):
    """Classification head: x(b,c1,20,20) -> x(b,c2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Classify, self).__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        # Accept a single tensor or a list of tensors (concatenated on channels).
        inputs = x if isinstance(x, list) else [x]
        z = torch.cat([self.aap(y) for y in inputs], 1)
        return self.flat(self.conv(z))  # flatten to x(b,c2)
|
||||
|
|
@ -0,0 +1,703 @@
|
|||
import json, sys, cv2, os, glob, random, time, datetime
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from DrGraph.util import torchHelper
|
||||
|
||||
class ServiceException(Exception):  # business-level service exception
    """Service exception carrying an error code and a (possibly templated) message.

    Args:
        code: application-specific error code.
        msg:  message text, or a %-style template when ``desc`` is given.
        desc: optional value(s) interpolated into ``msg`` via ``%``.
    """

    def __init__(self, code, msg, desc=None):
        self.code = code
        # When desc is provided, msg is treated as a %-format template.
        if desc is None:
            self.msg = msg
        else:
            self.msg = msg % desc

    def __str__(self):
        # Bug fix: __str__ previously returned None (it only logged), so
        # str(exc) / print(exc) raised "TypeError: __str__ returned non-string".
        # Keep the logging side effect, but return a proper string.
        logger.error("异常编码:{}, 异常描述:{}", self.code, self.msg)
        return "code=%s, msg=%s" % (self.code, self.msg)
|
||||
|
||||
|
||||
class mathHelper:
    """Small math utilities: RNG seeding and bounding-box geometry."""

    @staticmethod
    def init_seeds(seed=0):
        """Seed Python's, NumPy's and PyTorch's RNGs for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torchHelper.init_torch_seeds(seed)

    @staticmethod
    def center_coordinate(boundbxs):
        """Return the (x, y) center of a box given as [x1, y1, x2, y2]."""
        left, top, right, bottom = boundbxs[0], boundbxs[1], boundbxs[2], boundbxs[3]
        return 0.5 * (left + right), 0.5 * (top + bottom)

    @staticmethod
    def fourcorner_coordinate(boundbxs):
        """Expand a two-corner box [x1, y1, x2, y2] into its four corner points.

        Returns:
            [[top-left], [top-right], [bottom-right], [bottom-left]] in contour order.
        """
        x1, y1, x2, y2 = boundbxs[0], boundbxs[1], boundbxs[2], boundbxs[3]
        w = x2 - x1
        h = y2 - y1
        return [[x1, y1], [x1 + w, y1], [x2, y2], [x1, y1 + h]]

    @staticmethod
    def xywh2xyxy(box, iW=None, iH=None):
        """Convert a center-format box (xc, yc, w, h) to corner format (x0, y0, x1, y1).

        Coordinates are assumed normalized to [0, 1] (results are clamped to
        that range); pass iW/iH to scale the result to pixel coordinates.
        """
        xc, yc, w, h = box[0:4]
        left = max(0, xc - w / 2.0)
        right = min(1, xc + w / 2.0)
        top = max(0, yc - h / 2.0)
        bottom = min(1, yc + h / 2.0)
        if iW:
            left, right = left * iW, right * iW
        if iH:
            top, bottom = top * iH, bottom * iH
        return [left, top, right, bottom]
|
||||
|
||||
class ioHelper:  # IO helper class
    """File/IO helpers: label files, media discovery, post-process configs."""

    @staticmethod
    def get_labelnames(labelnames):
        """Read and return the 'labelnames' list from a JSON file."""
        with open(labelnames, 'r') as fp:
            return json.load(fp)['labelnames']

    @staticmethod
    def get_images_videos(destPath, imageFixs=['.jpg','.JPG','.PNG','.png'], videoFixs=['.MP4','.mp4','.avi']):
        """Collect image and video file paths under ``destPath`` (directory or single file).

        Args:
            destPath: directory to scan, or a single media file path.
            imageFixs: image extensions to match.
            videoFixs: video extensions to match.

        Returns:
            (image_paths, video_paths) lists.
        """
        images = []
        videos = []
        if os.path.isdir(destPath):
            for ext in imageFixs:
                images.extend(glob.glob('%s/*%s' % (destPath, ext)))
            for ext in videoFixs:
                videos.extend(glob.glob('%s/*%s' % (destPath, ext)))
        else:
            ext = os.path.splitext(destPath)[-1]
            if ext in imageFixs:
                images = [destPath]
            if ext in videoFixs:
                videos = [destPath]

        logger.info('目录 [%s]下多媒体文件数量: Images %d, videos %d' % (destPath, len(images), len(videos)))
        return images, videos

    @staticmethod
    def get_postProcess_para(parfile):
        """Return (conf_thres, iou_thres, classes, rainbows) from the 'post_process' section of a JSON parameter file.

        Raises:
            AssertionError: when the file has no 'post_process' key.
        """
        with open(parfile) as fp:
            cfg = json.load(fp)
        assert 'post_process' in cfg.keys(), ' parfile has not key word:post_process'
        post = cfg['post_process']
        return post["conf_thres"], post["iou_thres"], post["classes"], post["rainbows"]

    @staticmethod
    def get_postProcess_para_dic(parfile):
        """Return the whole 'post_process' dict from a JSON parameter file."""
        with open(parfile) as fp:
            return json.load(fp)['post_process']

    @staticmethod
    def checkFile(fileName, desc):
        """Log and report whether ``fileName`` exists; empty names count as existing.

        Returns:
            bool: True when the file exists (or the name is empty), else False.
        """
        if (len(fileName) > 0) and (os.path.exists(fileName) is False):
            logger.error(f"{desc} - {fileName} 不存在,请检查!")
            return False
        logger.info(f"{desc} - {fileName} 存在")
        return True
|
||||
|
||||
class timeHelper:
    """Time helpers: file modification dates and millisecond deltas."""

    @staticmethod
    def date_modified(path=__file__):
        """Return the modification date of ``path`` as 'YYYY-M-D' (no zero padding)."""
        mtime = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
        return '%d-%d-%d' % (mtime.year, mtime.month, mtime.day)

    @staticmethod
    def deltaTimeString_MS(t2, t1, floatNumber = 1):
        """Format (t2 - t1) seconds as a millisecond string with ``floatNumber`` decimals.

        Args:
            t2: end moment in seconds.
            t1: start moment in seconds.
            floatNumber: number of decimal places in the output.

        Returns:
            str: the elapsed time in milliseconds.
        """
        spec = '%' + '0.%df' % floatNumber
        return spec % ((t2 - t1) * 1000.0)

    @staticmethod
    def deltaTime_MS(t2, t1):  # get_ms
        """Return (t2 - t1) converted from seconds to milliseconds as a float."""
        return (t2 - t1) * 1000.0
|
||||
|
||||
class drawHelper:
    """Drawing helpers: render detection boxes, class-label bitmaps and scores on images."""

    @staticmethod
    def drawAllBox(preds,imgDraw,label_arraylist,rainbows,font):
        """Draw every detection in ``preds`` onto ``imgDraw``.

        Args:
            preds: detections, each as [x0, y0, x1, y1, conf, cls, ...].
            imgDraw: target image (numpy array, H x W x C).
            label_arraylist: pre-rendered label images, indexed by class id.
            rainbows: per-class colors (cycled modulo 20).
            font: dict of drawing params ('line_thickness', 'boxLine_thickness', 'fontSize').

        Returns:
            The image with all boxes drawn.
        """
        for box in preds:
            cls,conf,xyxy = box[5],box[4], box[0:4] ##2023.08.03, detection format changed to xyxy/conf/cls
            imgDraw = drawHelper.draw_painting_joint(xyxy,imgDraw,label_arraylist[int(cls)],score=conf,color=rainbows[int(cls)%20],font=font,socre_location="leftTop")
        return imgDraw

    @staticmethod
    def draw_painting_joint(box,img,label_array,score=0.5,color=None,font={ 'line_thickness':None,'boxLine_thickness':None, 'fontSize':None},socre_location="leftTop"):
        """Draw one detection: paste the class-label bitmap, outline the box, print the score.

        Args:
            box: either four corner points [(x0,y0),(x1,y1),(x2,y2),(x3,y3)]
                or two corners [x0, y0, x1, y1] (values may be torch tensors).
            img (numpy.ndarray): target image (H, W, C), modified in place.
            label_array (numpy.ndarray): pre-rendered class-label image (h, w, c).
            score (float): confidence printed next to the label.
            color: box color (B, G, R).
            font (dict): 'line_thickness', 'boxLine_thickness', 'fontSize' overrides.
                NOTE(review): mutable default dict is shared across calls — safe
                only because it is never mutated here.
            socre_location (str): 'leftTop' or 'leftBottom' ('socre' typo kept for API compat).

        Returns:
            numpy.ndarray: the image with the detection drawn.
        """
        # If box[0] is a list/tuple/array, box is the 4-point format [(x0,y0),...,(x3,y3)].
        if isinstance(box[0], (list, tuple,np.ndarray ) ):
            # First paste the class-label bitmap into img.
            lh, lw, lc = label_array.shape
            imh, imw, imc = img.shape
            if socre_location=='leftTop':
                x0 , y1 = box[0][0],box[0][1]
            elif socre_location=='leftBottom':
                x0,y1=box[3][0],box[3][1]
            else:
                print('plot.py line217 ,label_location:%s not implemented '%( socre_location ))
                sys.exit(0)

            x1 , y0 = x0 + lw , y1 - lh
            # Clamp the label rectangle so it stays fully inside the image.
            if y0<0:y0=0;y1=y0+lh
            if y1>imh: y1=imh;y0=y1-lh
            if x0<0:x0=0;x1=x0+lw
            if x1>imw:x1=imw;x0=x1-lw
            img[y0:y1,x0:x1,:] = label_array
            pts_cls=[(x0,y0),(x1,y1) ]

            # Draw the quadrilateral outline.
            box_tl= font['boxLine_thickness'] or round(0.002 * (imh + imw) / 2) + 1
            cv2.polylines(img, [box], True,color , box_tl)

            # Print the score text next to the class label.
            tl = font['line_thickness'] or round(0.002*(imh+imw)/2)+1  # line/font thickness
            label = ' %.2f'%(score)
            tf = max(tl , 1) # font thickness
            fontScale = font['fontSize'] or tl * 0.33
            t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]

            # Filled background rectangle then white score text.
            p1,p2= (pts_cls[1][0], pts_cls[0][1]),(pts_cls[1][0]+t_size[0],pts_cls[1][1])
            cv2.rectangle(img, p1 , p2, color, -1, cv2.LINE_AA)
            p3 = pts_cls[1][0],pts_cls[1][1]-(lh-t_size[1])//2

            cv2.putText(img, label,p3, 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
            return img
        else:  #### two-point format [x0,y0,x1,y1]
            # Coordinates may be torch tensors (cpu() them); otherwise plain ints.
            try:
                box = [int(xx.cpu()) for xx in box]
            except:
                box=[ int(x) for x in box]
            # First paste the class-label bitmap into img.
            lh, lw, lc = label_array.shape
            imh, imw, imc = img.shape
            if socre_location=='leftTop':
                x0 , y1 = box[0:2]
            elif socre_location=='leftBottom':
                x0,y1=box[0],box[3]
            else:
                print('plot.py line217 ,socre_location:%s not implemented '%( socre_location ))
                sys.exit(0)
            x1 , y0 = x0 + lw , y1 - lh
            # Clamp the label rectangle so it stays fully inside the image.
            if y0<0:y0=0;y1=y0+lh
            if y1>imh: y1=imh;y0=y1-lh
            if x0<0:x0=0;x1=x0+lw
            if x1>imw:x1=imw;x0=x1-lw
            img[y0:y1,x0:x1,:] = label_array

            # Draw the rectangle with the requested color and line width.
            tl = font['line_thickness'] or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
            box_tl= font['boxLine_thickness'] or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
            c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(img, c1, c2, color, thickness=box_tl, lineType=cv2.LINE_AA)

            # Print the score text next to the class label.
            label = ' %.2f'%(score)
            tf = max(tl , 1) # font thickness
            fontScale = font['fontSize'] or tl * 0.33
            t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]

            if socre_location=='leftTop':
                c2 = c1[0]+ lw + t_size[0], c1[1] - lh
                cv2.rectangle(img, (int(box[0])+lw,int(box[1])) , c2, color, -1, cv2.LINE_AA) # filled
                cv2.putText(img, label, (c1[0]+lw, c1[1] - (lh-t_size[1])//2 ), 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
            elif socre_location=='leftBottom':
                c2 = box[0]+ lw + t_size[0], box[3] - lh
                cv2.rectangle(img, (int(box[0])+lw,int(box[3])) , c2, color, -1, cv2.LINE_AA) # filled
                cv2.putText(img, label, ( box[0] + lw, box[3] - (lh-t_size[1])//2 ), 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

            return img
|
||||
|
||||
class imgHelper:
    """Image helpers: letterbox/pad resizing, label bitmaps and box coordinate transforms."""

    @staticmethod
    def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
        """Resize ``img`` keeping aspect ratio and pad the borders to ``new_shape``.

        Args:
            img (numpy.ndarray): input image (H, W, C).
            new_shape (int or tuple): target size; int means square, else (height, width).
            color (tuple): border fill color (B, G, R).
            auto (bool): pad only up to the next multiple of ``stride`` (minimal rectangle).
            scaleFill (bool): stretch to fill exactly, ignoring aspect ratio.
            scaleup (bool): allow upscaling; when False only downscale (better test mAP).
            stride (int): stride the output dims must be a multiple of (auto mode).

        Returns:
            (img, ratio, (dw, dh)): padded image, (width, height) scale ratios and
            the per-side padding in pixels.
        """
        shape = img.shape[:2]  # current shape [height, width]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale ratio (new / old), chosen so the image is never cropped.
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not scaleup:  # only scale down, do not scale up (for better test mAP)
            r = min(r, 1.0)

        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

        if auto:  # minimum rectangle: pad only to a multiple of stride
            dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
        elif scaleFill:  # stretch (no aspect-ratio preservation)
            dw, dh = 0.0, 0.0
            new_unpad = (new_shape[1], new_shape[0])
            ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

        dw /= 2  # divide padding into 2 sides
        dh /= 2

        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
        return img, ratio, (dw, dh)

    @staticmethod
    def img_pad(img, size, pad_value=[114, 114, 114]):
        """Shrink ``img`` to fit inside ``size`` (h, w) and center-pad to exactly that size.

        Returns:
            (padded_image, (top, left, r)): padding offsets and the scale factor r
            (original / resized), as consumed by :meth:`scale_back`.
        """
        H, W, _ = img.shape
        r = max(H / size[0], W / size[1])  # shrink factor so the image fits
        img_r = cv2.resize(img, (int(W / r), int(H / r)))
        tb = size[0] - img_r.shape[0]
        lr = size[1] - img_r.shape[1]
        top = int(tb / 2)
        bottom = tb - top
        left = int(lr / 2)
        right = lr - left
        pad_image = cv2.copyMakeBorder(img_r, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_value)
        return pad_image, (top, left, r)

    @staticmethod
    def get_label_array(color=None, label=None, outfontsize=None, fontpath="./DrGraph/appIOs/conf/platech.ttf"):
        """Render ``label`` text on a solid ``color`` background and return it as a numpy image.

        Args:
            color: background RGB tuple/list.
            label: text to render (CJK supported via the TTF at ``fontpath``).
            outfontsize: font size; also used to rescale the rendered strip height.
            fontpath: path to the TTF font file.

        Returns:
            numpy.ndarray: the rendered label image.
        """
        fontsize = outfontsize
        font = ImageFont.truetype(fontpath, fontsize, encoding='utf-8')

        # NOTE(review): ImageFont.getsize was removed in Pillow >= 10; this code
        # requires an older Pillow (or should migrate to getbbox/getlength) — confirm.
        txt_width, txt_height = font.getsize(label)
        im = np.zeros((txt_height, txt_width, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        draw = ImageDraw.Draw(im)
        draw.rectangle([0, 0, txt_width, txt_height], fill=tuple(color))
        draw.text((0, -3), label, fill=(255, 255, 255), font=font)
        im_array = np.asarray(im)

        if outfontsize:
            # Rescale so the strip height equals outfontsize exactly.
            scaley = outfontsize / txt_height
            im_array = cv2.resize(im_array, (0, 0), fx=scaley, fy=scaley)
        return im_array

    @staticmethod
    def get_label_arrays(labelnames, colors, outfontsize=40, fontpath="./DrGraph/appIOs/conf/platech.ttf"):
        """Render one label image per name in ``labelnames`` (colors cycled modulo 20)."""
        label_arraylist = []
        if len(labelnames) > len(colors):
            print('#####labelnames cnt > colors cnt#####')
        for ii, labelname in enumerate(labelnames):
            color = colors[ii % 20]
            label_arraylist.append(imgHelper.get_label_array(color=color, label=labelname, outfontsize=outfontsize, fontpath=fontpath))
        return label_arraylist

    @staticmethod
    def clip_coords(boxes, img_shape):
        """Clamp xyxy ``boxes`` (torch tensor, modified in place) to the image bounds.

        Consistency fix: this was the only method in the class missing
        @staticmethod; it is always called as ``imgHelper.clip_coords(...)``
        (see scale_coords), so behavior is unchanged, but the decorator makes
        instance access safe and matches every sibling method.

        Args:
            boxes: (n, 4) tensor of xyxy coordinates.
            img_shape: (height, width) of the target image.
        """
        # Clip bounding xyxy bounding boxes to image shape (height, width)
        boxes[:, 0].clamp_(0, img_shape[1])  # x1
        boxes[:, 1].clamp_(0, img_shape[0])  # y1
        boxes[:, 2].clamp_(0, img_shape[1])  # x2
        boxes[:, 3].clamp_(0, img_shape[0])  # y2

    @staticmethod
    def scale_back(boxes, padInfos):
        """Map xyxy ``boxes`` from padded-image space back to original-image space.

        Args:
            boxes: (n, 4) array of xyxy coordinates, modified in place.
            padInfos: (top, left, r) as returned by :meth:`img_pad`.

        Returns:
            The rescaled boxes (same object as the input).
        """
        top, left, r = padInfos[0:3]
        boxes[:, 0] = (boxes[:, 0] - left) * r
        boxes[:, 2] = (boxes[:, 2] - left) * r
        boxes[:, 1] = (boxes[:, 1] - top) * r
        boxes[:, 3] = (boxes[:, 3] - top) * r
        return boxes

    @staticmethod
    def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
        """Rescale xyxy ``coords`` from letterboxed ``img1_shape`` back to ``img0_shape``.

        Args:
            img1_shape: (height, width) the coords currently live in.
            coords: xyxy coordinate array/tensor, modified in place.
            img0_shape: original (height, width).
            ratio_pad: ((gain, _), (pad_w, pad_h)); recomputed from the shapes when None.

        Returns:
            The rescaled (and clipped) coords.
        """
        if ratio_pad is None:  # calculate from img0_shape
            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
        else:
            gain = ratio_pad[0][0]
            pad = ratio_pad[1]

        coords[:, [0, 2]] -= pad[0]  # x padding
        coords[:, [1, 3]] -= pad[1]  # y padding
        coords[:, :4] /= gain
        imgHelper.clip_coords(coords, img0_shape)
        return coords

    @staticmethod
    def expand_rectangle(rec, imgSize, ex_width, ex_height):
        """Expand an xyxy rectangle outward, clamped to the image bounds.

        Args:
            rec: [x1, y1, x3, y3] top-left / bottom-right corners.
            imgSize: (width, height) of the image.
            ex_width: expansion in pixels on each horizontal side.
            ex_height: expansion in pixels on each vertical side.

        Returns:
            list: the expanded, clamped [x1, y1, x3, y3].
        """
        img_width, img_height = imgSize[0:2]
        x1 = rec[0]
        y1 = rec[1]
        x3 = rec[2]
        y3 = rec[3]

        x1 = x1 - ex_width if x1 - ex_width >= 0 else 0
        y1 = y1 - ex_height if y1 - ex_height >= 0 else 0
        x3 = x3 + ex_width if x3 + ex_width <= img_width else img_width
        y3 = y3 + ex_height if y3 + ex_height <= img_height else img_height
        return [x1, y1, x3, y3]
|
||||
|
||||
class TimeDebugger:
    """Hierarchical wall-clock profiler, used as a context manager.

    Nested ``with TimeDebugger(...)`` blocks form a tree through the
    class-level ``currentDebugger`` pointer; ``addStep`` records named
    checkpoints and ``getReportInfo`` renders the tree as an indented report.
    NOTE(review): the shared class-level pointer assumes single-threaded use.
    """
    currentDebugger = None  # innermost active debugger (the nesting parent)

    def __init__(self, bussinessName, enabled = True, logAtExit = False):
        """Create a debugger; ``enabled=False`` makes the instance inert."""
        self.enabled = enabled
        if not enabled:
            return
        self.indexInParentStep = 0
        if TimeDebugger.currentDebugger:
            # Nested usage: attach self as a child of the active debugger and
            # remember at which parent step we were created so the report can
            # interleave child reports in chronological order.
            self.parent = TimeDebugger.currentDebugger
            self.level = self.parent.level + 1
            self.parent.children.append(self)
            self.indexInParentStep = len(self.parent.stepMoments)
        else:
            self.parent = None
            self.level = 1
        self.bussinessName = bussinessName
        self.logAtExit = logAtExit
        self.withDetail_MS = False  # include truncated raw timestamps in reports
        self.reset()

    def reset(self):
        """Clear recorded steps/children and restart the clock."""
        self.stepMoments = []  # list of (time, msg, caller-location) tuples
        self.startMoment = time.time()
        self.children = []
        self.exitMoment = 0  # 0 means "still running"

    def __enter__(self):
        if self.enabled:
            TimeDebugger.currentDebugger = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enabled:
            # Pop self: the parent becomes the active debugger again.
            TimeDebugger.currentDebugger = self.parent
            self.exitMoment = time.time()
            if self.logAtExit:
                logger.info(self.getReportInfo())

    def addStep(self, msg):
        """Record a named checkpoint together with the caller's file:line."""
        if self.enabled:
            import inspect
            frame = inspect.currentframe().f_back
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno
            function_name = frame.f_code.co_name
            callerLocation = f"{os.path.basename(filename)}:{lineno} in {function_name}"

            self.stepMoments.append((time.time(), msg, callerLocation))

    def getStepInfo(self, index, lastTime):
        """Format one step line: elapsed ms since ``lastTime`` plus the caller location."""
        (t, msg, callerLocation) = self.stepMoments[index]
        info = '\n' + '\t' * (self.level + 1) + '%s: %s 毫秒' % (msg, timeHelper.deltaTimeString_MS(t, lastTime))
        if self.withDetail_MS:
            info += '(%d ~ %d)' % (lastTime * 10000 % 1000000, t * 10000 % 1000000)
        # Pad with tabs so the caller locations roughly line up in columns.
        l = 3 - (len(info) // 25)
        info += '\t' * l if l > 0 else '\t'
        info += callerLocation
        return info

    def getReportInfo(self, desc = ""):
        """Render this debugger (and its children, interleaved by step index) as a report string."""
        # Use the exit time when finished, otherwise "now".
        t = time.time() if self.exitMoment == 0 else self.exitMoment
        info = desc if len(desc) > 0 else ""
        info += '[%s]业务 总共耗时 %s 毫秒,其中:' % (self.bussinessName, timeHelper.deltaTimeString_MS(t, self.startMoment))
        if self.withDetail_MS:
            info += '(%d ~ %d)' % (self.startMoment * 10000 % 1000000, t * 10000 % 1000000)
        lastTime = self.startMoment
        nextStepIndex = 0
        for child in self.children:
            # Emit our own steps up to (and including) the step at which the
            # child was created, then the child's own report.
            childStepIndex = child.indexInParentStep
            if childStepIndex == len(self.stepMoments):
                childStepIndex -= 1
            for i in range(nextStepIndex, childStepIndex + 1):
                info += self.getStepInfo(i, lastTime)
                lastTime = self.stepMoments[i][0]
                nextStepIndex = i + 1
            info += ' -> ' + child.getReportInfo()
        # Remaining steps after the last child.
        for i in range(nextStepIndex, len(self.stepMoments)):
            info += self.getStepInfo(i, lastTime)
            lastTime = self.stepMoments[i][0]
        return info
|
||||
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
# YOLOv5 experimental modules
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
from util.common import Conv, DWConv
|
||||
from util.google_utils import attempt_download
|
||||
|
||||
|
||||
class CrossConv(nn.Module):
    # Cross Convolution Downsample
    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
        """Factor a kxk conv into a 1xk conv followed by a kx1 conv, optionally with a residual add."""
        super(CrossConv, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, (1, k), (1, s))
        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
        # Residual add only when requested and channel counts match.
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply both factored convs; add the input when ``self.add`` is set."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
||||
|
||||
|
||||
class Sum(nn.Module):
    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, n, weight=False):  # n: number of inputs
        """Sum ``n`` inputs, optionally with learned sigmoid-bounded per-input gains."""
        super(Sum, self).__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # iter object
        if weight:
            # Learnable logits; sigmoid(w) * 2 keeps each gain in (0, 2).
            self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True)  # layer weights

    def forward(self, x):
        """Return x[0] plus the (optionally weighted) remaining inputs."""
        total = x[0]
        if self.weight:
            gains = torch.sigmoid(self.w) * 2
            for idx in self.iter:
                total = total + x[idx + 1] * gains[idx]
        else:
            for idx in self.iter:
                total = total + x[idx + 1]
        return total
|
||||
|
||||
|
||||
class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        """Half the outputs come from a primary conv, the other half from a cheap 5x5 depthwise conv."""
        super(GhostConv, self).__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)  # groups == channels -> depthwise

    def forward(self, x):
        """Concatenate the primary features with their cheap 'ghost' transform."""
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)
|
||||
|
||||
|
||||
class GhostBottleneck(nn.Module):
    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        """GhostConv squeeze -> optional depthwise downsample (s == 2) -> linear GhostConv expand."""
        super(GhostBottleneck, self).__init__()
        c_ = c2 // 2
        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        # Shortcut must downsample and match channels when stride is 2.
        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
                                      Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

    def forward(self, x):
        """Residual sum of the bottleneck path and the shortcut path."""
        return self.conv(x) + self.shortcut(x)
|
||||
|
||||
|
||||
class MixConv2d(nn.Module):
    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        """One parallel conv per kernel size in ``k``; branch outputs are concatenated to c2 channels.

        equal_ch=True splits c2 evenly across the branches; otherwise the
        channel split is solved so each branch has roughly equal weight count.
        """
        super(MixConv2d, self).__init__()
        groups = len(k)
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
        else:  # equal weight.numel() per group
            # Solve the small linear system a @ c_ = b for the channel split.
            b = [c2] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        """Residual add of the activated, batch-normed concat of all branch outputs.

        NOTE(review): the ``x +`` residual requires c1 == c2 to broadcast — confirm callers.
        """
        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        """An empty ModuleList; member models are appended by attempt_load."""
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        """Run every member model and concatenate their detections along dim 1 (NMS-style ensemble)."""
        y = []
        for module in self:
            y.append(module(x, augment)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output
|
||||
|
||||
|
||||
def attempt_load(weights, map_location=None):
    """Load a model (or an Ensemble of models) from one or more checkpoint paths.

    Args:
        weights: a single path or a list of paths [a, b, c].
        map_location: forwarded to torch.load for device remapping.

    Returns:
        A single fused FP32 model when one weight is given; otherwise an
        Ensemble whose 'names'/'stride' are copied from the last model.

    Raises:
        AssertionError: when a weight file does not exist.
    """
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        # Bug fix: the message was the literal "%s not exists" — the missing
        # path was never interpolated into the assertion text.
        assert os.path.exists(w), "%s not exists" % w
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files.
        ckpt = torch.load(w, map_location=map_location)  # load
        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

    # Compatibility updates for checkpoints saved with older PyTorch versions.
    for m in model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    print('Ensemble created with %s\n' % weights)
    for k in ['names', 'stride']:
        setattr(model, k, getattr(model[-1], k))
    return model  # return ensemble
|
||||
|
|
@ -0,0 +1,614 @@
|
|||
# YOLOv5 general utils
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision
|
||||
import yaml
|
||||
|
||||
from DrGraph.util.google_utils import gsutil_getsize
|
||||
from DrGraph.util.drHelper import *
|
||||
|
||||
# Settings: module-import-time global configuration for printing and threading.
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
# OpenCV's internal threading conflicts with PyTorch DataLoader workers.
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
|
||||
|
||||
def fitness(x):
    """Weighted model fitness from metric columns [P, R, mAP@0.5, mAP@0.5:0.95].

    Args:
        x: (n, >=4) array whose first four columns are the metrics above.

    Returns:
        (n,) array of fitness scores (the mAP terms dominate).
    """
    metric_weights = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    weighted = x[:, :4] * metric_weights
    return weighted.sum(1)
|
||||
|
||||
|
||||
def get_latest_run(search_dir='.'):
    """Return the most recently created 'last*.pt' under ``search_dir`` (for --resume), or '' if none."""
    candidates = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    if not candidates:
        return ''
    return max(candidates, key=os.path.getctime)
|
||||
|
||||
|
||||
def isdocker():
    """Return True when the environment looks like a Docker container (heuristic: /workspace exists)."""
    workspace = Path('/workspace')  # or Path('/.dockerenv').exists()
    return workspace.exists()
|
||||
|
||||
|
||||
def emojis(str=''):
    """Return an emoji-safe version of the string: non-ASCII stripped on Windows, unchanged elsewhere.

    NOTE: the parameter name shadows the builtin ``str``; kept for caller compatibility.
    """
    if platform.system() == 'Windows':
        return str.encode().decode('ascii', 'ignore')
    return str
|
||||
|
||||
|
||||
def check_online():
    """Return True when the host can open a TCP connection to 1.1.1.1:443 within 5 seconds."""
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
    except OSError:
        return False
    return True
|
||||
|
||||
|
||||
def check_git_status():
    """Print a recommendation to 'git pull' when the local repo is behind origin/master.

    Degrades gracefully (prints the skip reason) when not in a git repo, when
    running in Docker, or when offline; any other failure is printed, not raised.
    """
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        # The assertion messages double as human-readable skip reasons,
        # surfaced by the except handler below.
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not isdocker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git')  # github repo url
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(e)
|
||||
|
||||
|
||||
def check_requirements(requirements='requirements.txt', exclude=()):
    """Verify installed packages satisfy ``requirements`` (a *.txt path or an iterable of specs).

    Missing or conflicting packages are auto-installed via ``pip install`` and
    the outcome printed; nothing is raised on failure.

    Args:
        requirements: path to a requirements file, or a list/tuple of specs.
        exclude: package names to skip.
    """
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    import pkg_resources as pkg
    prefix = colorstr('red', 'bold', 'requirements:')
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():
            print(f"{prefix} {file.resolve()} not found, check failed.")
            return
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
            print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
            print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe
|
||||
|
||||
|
||||
def check_img_size(img_size, s=32):
    """Verify img_size is a multiple of stride *s*; warn and round up if not."""
    new_size = make_divisible(img_size, int(s))  # ceil to a gs-multiple
    if new_size == img_size:
        return img_size
    print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size
|
||||
|
||||
|
||||
def check_imshow():
    """Return True if the environment supports cv2.imshow() image displays.

    Probes by showing a 1x1 black image; Docker (and headless servers)
    raise, which is caught and reported as unsupported.
    """
    try:
        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)  # second waitKey lets the window actually close
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
|
||||
|
||||
|
||||
def check_file(file):
    """Return *file* if it exists (or is ''), else search for it recursively.

    Raises AssertionError when the file is not found or the name is ambiguous.
    """
    if file == '' or Path(file).is_file():
        return file
    matches = glob.glob('./**/' + file, recursive=True)  # find file
    assert len(matches), f'File Not Found: {file}'  # file was found
    assert len(matches) == 1, f"Multiple files match '{file}', specify exact path: {matches}"  # unique
    return matches[0]
|
||||
|
||||
|
||||
def check_dataset(dict):
    """Download the dataset if its 'val' paths are not found locally.

    NOTE(review): the parameter name `dict` shadows the builtin; kept for
    caller compatibility. Expects keys 'val' (path or list of paths) and
    optionally 'download' (URL to a .zip, or a bash script string).
    """
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip into parent dir
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')
|
||||
|
||||
|
||||
def make_divisible(x, divisor):
    """Return the smallest multiple of *divisor* that is >= x."""
    quotient = math.ceil(x / divisor)
    return quotient * divisor
|
||||
|
||||
|
||||
def clean_str(s):
    """Replace shell/URL-unsafe special characters in *s* with underscores."""
    special = re.compile("[|@#!¡·$€%&()=?¿^*;:,¨´><+]")
    return special.sub("_", s)
|
||||
|
||||
|
||||
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a function giving a sinusoidal ramp from y1 to y2 over *steps*."""
    def ramp(x):
        return ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
    return ramp
|
||||
|
||||
|
||||
def colorstr(*inputs):
    """Color a string with ANSI escape codes, e.g. colorstr('blue', 'hello world').

    The last positional argument is the string; preceding arguments name the
    colors/attributes. A single argument defaults to blue + bold.
    """
    *args, string = inputs if len(inputs) > 1 else ('blue', 'bold', inputs[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    prefix = ''.join(colors[attr] for attr in args)
    return prefix + str(string) + colors['end']
|
||||
|
||||
|
||||
def labels_to_class_weights(labels, nc=80):
    """Compute inverse-frequency class weights from training labels.

    Args:
        labels: sequence of (n, 5) arrays, each row [class, x, y, w, h].
        nc: number of classes.

    Returns:
        torch.Tensor of shape (nc,): normalized inverse-frequency weights,
        or an empty tensor when no labels are loaded.
    """
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    # FIX: np.int is deprecated (removed in NumPy>=1.24); use builtin int.
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    weights[weights == 0] = 1  # replace empty bins with 1 so 1/weights is finite
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)
|
||||
|
||||
|
||||
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """Produce per-image sampling weights from class weights and label contents.

    Args:
        labels: per-image (n, 5) arrays, each row [class, x, y, w, h].
        nc: number of classes.
        class_weights: (nc,) array of per-class weights. NOTE: the default
            array is created once at definition time; it is only read here,
            never mutated, so sharing it is safe.

    Returns:
        (len(labels),) array: weighted class-count per image.
    """
    # FIX: np.int is deprecated (removed in NumPy>=1.24); use builtin int.
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # usage example
    return image_weights
|
||||
|
||||
|
||||
def coco80_to_coco91_class():
    """Map 80-index COCO (val2014) class ids to 91-index (paper) ids."""
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    return [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
            21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
            41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
            59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
            80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
||||
|
||||
|
||||
def xyxy2xywh(x):
    """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] (center + size)."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[:, 2] = x[:, 2] - x[:, 0]  # width
    out[:, 3] = x[:, 3] - x[:, 1]  # height
    out[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    out[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    return out
|
||||
|
||||
|
||||
def xywh2xyxy(x):
    """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] (corners)."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2
    out[:, 0] = x[:, 0] - half_w  # top-left x
    out[:, 1] = x[:, 1] - half_h  # top-left y
    out[:, 2] = x[:, 0] + half_w  # bottom-right x
    out[:, 3] = x[:, 1] + half_h  # bottom-right y
    return out
|
||||
|
||||
|
||||
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized [x, y, w, h] boxes to pixel [x1, y1, x2, y2] with padding."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w, half_h = x[:, 2] / 2, x[:, 3] / 2
    out[:, 0] = w * (x[:, 0] - half_w) + padw  # top-left x
    out[:, 1] = h * (x[:, 1] - half_h) + padh  # top-left y
    out[:, 2] = w * (x[:, 0] + half_w) + padw  # bottom-right x
    out[:, 3] = h * (x[:, 1] + half_h) + padh  # bottom-right y
    return out
|
||||
|
||||
|
||||
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized (n,2) segment points into pixel coordinates."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[:, 0] = x[:, 0] * w + padw  # x
    out[:, 1] = x[:, 1] * h + padh  # y
    return out
|
||||
|
||||
|
||||
def segment2box(segment, width=640, height=640):
    """Clip one (n,2) segment to the image and return its xyxy bounding box.

    Returns a zeros (1,4) array when no clipped x-coordinate is truthy
    (matching the original `any(x)` behavior).
    """
    xs, ys = segment.T  # segment xy
    keep = (xs >= 0) & (ys >= 0) & (xs <= width) & (ys <= height)
    xs, ys = xs[keep], ys[keep]
    if any(xs):
        return np.array([xs.min(), ys.min(), xs.max(), ys.max()])  # xyxy
    return np.zeros((1, 4))
|
||||
|
||||
|
||||
def segments2boxes(segments):
    """Convert segment labels to xywh box labels, i.e. (xy1, xy2, ...) -> (xywh)."""
    corners = []
    for seg in segments:
        xs, ys = seg.T  # segment xy
        corners.append([xs.min(), ys.min(), xs.max(), ys.max()])  # xyxy
    return xyxy2xywh(np.array(corners))  # xywh
|
||||
|
||||
|
||||
def resample_segments(segments, n=1000):
    """Up-sample each (m,2) segment to n points by linear interpolation (in place)."""
    for k, s in enumerate(segments):
        sample = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        # interpolate x and y columns separately, then stack back to (n,2)
        cols = [np.interp(sample, xp, s[:, j]) for j in range(2)]
        segments[k] = np.concatenate(cols).reshape(2, -1).T  # segment xy
    return segments
|
||||
|
||||
|
||||
|
||||
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """Return the IoU (or GIoU/DIoU/CIoU) of box1 to box2.

    box1 is a single box (4,), box2 is (n,4); returns an (n,) tensor.
    Exactly one of GIoU/DIoU/CIoU may be set; otherwise plain IoU.
    """
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area (eps keeps heights/union non-zero to avoid division by zero)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                # aspect-ratio consistency term
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():  # alpha is treated as a constant in the gradient
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
|
||||
|
||||
|
||||
def box_iou(box1, box2):
    """Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): pairwise IoU for every pair from box1 x box2.
    """
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    def area(b):  # b is 4xn
        return (b[2] - b[0]) * (b[3] - b[1])

    area1 = area(box1.T)
    area2 = area(box2.T)

    # intersection: clamp negative overlaps to zero, then w*h
    lt = torch.max(box1[:, None, :2], box2[:, :2])  # (N,M,2)
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # (N,M,2)
    inter = (rb - lt).clamp(0).prod(2)  # (N,M)
    return inter / (area1[:, None] + area2 - inter)
|
||||
|
||||
|
||||
def wh_iou(wh1, wh2):
    """Return the NxM IoU matrix of width-height pairs (wh1 is Nx2, wh2 is Mx2).

    Boxes are treated as if anchored at a common origin.
    """
    a = wh1[:, None]  # [N,1,2]
    b = wh2[None]  # [1,M,2]
    inter = torch.min(a, b).prod(2)  # [N,M]
    union = a.prod(2) + b.prod(2) - inter
    return inter / union
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: (batch, n_boxes, 5 + nc) raw model output [xywh, obj, cls...].
        conf_thres: objectness threshold for candidate boxes.
        iou_thres: IoU threshold for torchvision NMS.
        classes: optional list of class ids to keep.
        agnostic: class-agnostic NMS if True.
        multi_label: allow multiple labels per box.
        labels: optional apriori labels per image for autolabelling.

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    # candidates: the upper bound (< 1.0000001) also rejects non-finite/corrupt scores
    xc = (prediction[..., 4] > conf_thres) & ( prediction[..., 4] < 1.0000001 )  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: offset boxes by class * max_wh so classes never overlap
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
|
||||
def overlap_box_suppression(prediction, ovlap_thres = 0.6):
    """Runs overlap_box_suppression on inference results.

    Deletes boxes whose overlap ratio with a higher-scoring box of the same
    class exceeds ovlap_thres (intersection over either box's own area).

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    def box_iob(box1, box2):
        # Intersection-over-box: returns inter/area1 and inter/area2.
        def box_area(box):
            return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])

        area1 = box_area(box1)  # (N,)
        area2 = box_area(box2)  # (M,)

        # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
        lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]  each of N vs each of M
        rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]
        wh = (rb - lt).clamp(min=0)  # clamp negative extents to 0 (no overlap)
        inter = wh[:, :, 0] * wh[:, :, 1]

        return torch.squeeze(inter / area1), torch.squeeze(inter / area2)

    output = [torch.zeros((0, 6), device=prediction[0].device)] * len(prediction)
    for i, x in enumerate(prediction):
        keep = []  # indices into boxes retained in the final result
        boxes = x[:, 0:4]
        scores = x[:, 4]
        cls = x[:, 5]
        idxs = scores.argsort()  # ascending; best box is last
        while idxs.numel() > 0:
            keep_idx = idxs[-1]
            keep_box = boxes[keep_idx][None, ]  # [1, 4]
            keep.append(keep_idx)
            if idxs.size(0) == 1:
                break
            # drop the best box from the index list, then compare the rest to it
            idxs = idxs[:-1]
            other_boxes = boxes[idxs]
            this_cls = cls[keep_idx]
            other_cls = cls[idxs]
            iobs1, iobs2 = box_iob(keep_box, other_boxes)  # 1xM: best box vs the rest
            # keep boxes with small overlap (both ratios) or a different class
            idxs = idxs[((iobs1 <= ovlap_thres) & (iobs2 <= ovlap_thres)) | (other_cls != this_cls)]
        keep = idxs.new(keep)  # Tensor
        output[i] = x[keep]
    return output
|
||||
|
||||
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    """Strip optimizer state from checkpoint 'f' to finalize training.

    Saves the stripped checkpoint back to 'f', or to 's' if given.
    """
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False  # inference-only from here on
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
|
||||
|
||||
|
||||
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    """Print mutation results to evolve.txt (for use with train.py --evolve).

    Appends results + hyperparameters to evolve.txt, keeps it sorted by
    fitness, and writes the best row's hyperparameters to *yaml_file*.
    Optionally syncs evolve.txt with a GCS bucket.
    """
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])  # first 7 columns are the results
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload
|
||||
|
||||
|
||||
def apply_classifier(x, model, img, im0):
    """Apply a second-stage classifier to YOLO detections.

    For each detection, crops a padded square from the original image,
    classifies it, and keeps only detections whose classifier prediction
    matches the detector's class.
    """
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            imgHelper.scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes predicted by the detector
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=True, sep=''):
    """Increment a run path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ..."""
    path = Path(path)  # os-agnostic
    if not path.exists() or exist_ok:
        return str(path)
    similar = glob.glob(f"{path}{sep}*")  # similar paths
    pattern = rf"%s{sep}(\d+)" % path.stem
    indices = []
    for d in similar:
        m = re.search(pattern, d)
        if m:
            indices.append(int(m.groups()[0]))
    n = max(indices) + 1 if indices else 2  # increment number
    return f"{path}{sep}{n}"  # updated path
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
# Google utils: https://cloud.google.com/storage/docs/reference/libraries
|
||||
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
|
||||
def gsutil_getsize(url=''):
    """Return the size in bytes of a Google Cloud Storage object.

    Args:
        url: gs://bucket/file path.

    Returns:
        int: size in bytes, or 0 when gsutil produced no output.

    Note:
        Requires gsutil to be installed and configured.
    """
    # gs://bucket/file size  https://cloud.google.com/storage/docs/gsutil/commands/du
    s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
    # SECURITY/ROBUSTNESS FIX: parse the numeric size with int() instead of
    # calling eval() on subprocess output.
    return int(s.split(' ')[0]) if len(s) else 0  # bytes
|
||||
|
||||
|
||||
def attempt_download(file, repo='ultralytics/yolov5'):
    """Attempt to download *file* from the repo's GitHub releases if missing.

    Falls back to a Google Cloud Storage mirror only when `redundant` is
    enabled. Partial downloads (< 1 MB) are removed.
    """
    file = Path(str(file).strip().replace("'", '').lower())

    if not file.exists():
        try:
            response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
            assets = [x['name'] for x in response['assets']]  # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
            tag = response['tag_name']  # i.e. 'v1.0'
        except:  # fallback plan: hard-coded asset names + latest local git tag
            assets = ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']
            tag = subprocess.check_output('git tag', shell=True).decode().split()[-1]

        name = file.name
        if name in assets:
            msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
            redundant = False  # second download option
            try:  # GitHub
                url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
                print(f'Downloading {url} to {file}...')
                torch.hub.download_url_to_file(url, file)
                assert file.exists() and file.stat().st_size > 1E6  # check
            except Exception as e:  # GCP
                print(f'Download error: {e}')
                assert redundant, 'No secondary mirror'
                url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
                print(f'Downloading {url} to {file}...')
                os.system(f'curl -L {url} -o {file}')  # torch.hub.download_url_to_file(url, weights)
            finally:
                # runs on both success and failure; the `return` here exits the function
                if not file.exists() or file.stat().st_size < 1E6:  # check
                    file.unlink(missing_ok=True)  # remove partial downloads
                    print(f'ERROR: Download failure: {msg}')
                print('')
                return
|
||||
|
||||
|
||||
def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
    """Download a file from Google Drive via curl; unzip if it is a .zip.

    Returns the curl exit code (0 on success).
    """
    t = time.time()
    file = Path(file)
    cookie = Path('cookie')  # gdrive cookie
    print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
    file.unlink(missing_ok=True)  # remove existing file
    cookie.unlink(missing_ok=True)  # remove existing cookie

    # Attempt file download; the first request only collects the cookie
    out = "NUL" if platform.system() == "Windows" else "/dev/null"
    os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}')
    if os.path.exists('cookie'):  # large file: needs the confirm token from the cookie
        s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}'
    else:  # small file
        s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
    r = os.system(s)  # execute, capture return
    cookie.unlink(missing_ok=True)  # remove existing cookie

    # Error check
    if r != 0:
        file.unlink(missing_ok=True)  # remove partial
        print('Download error ')  # raise Exception('Download error')
        return r

    # Unzip if archive
    if file.suffix == '.zip':
        print('unzipping... ', end='')
        os.system(f'unzip -q {file}')  # unzip
        file.unlink()  # remove zip to free space

    print(f'Done ({time.time() - t:.1f}s)')
    return r
|
||||
|
||||
|
||||
def get_token(cookie="./cookie"):
    """Extract the Google Drive download-confirm token from a curl cookie file.

    Returns '' when no 'download' line is present.
    """
    with open(cookie) as handle:
        lines = handle.readlines()
    for line in lines:
        if "download" in line:
            return line.split()[-1]
    return ""
|
||||
|
||||
# def upload_blob(bucket_name, source_file_name, destination_blob_name):
|
||||
# # Uploads a file to a bucket
|
||||
# # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
|
||||
#
|
||||
# storage_client = storage.Client()
|
||||
# bucket = storage_client.get_bucket(bucket_name)
|
||||
# blob = bucket.blob(destination_blob_name)
|
||||
#
|
||||
# blob.upload_from_filename(source_file_name)
|
||||
#
|
||||
# print('File {} uploaded to {}.'.format(
|
||||
# source_file_name,
|
||||
# destination_blob_name))
|
||||
#
|
||||
#
|
||||
# def download_blob(bucket_name, source_blob_name, destination_file_name):
|
||||
# # Uploads a blob from a bucket
|
||||
# storage_client = storage.Client()
|
||||
# bucket = storage_client.get_bucket(bucket_name)
|
||||
# blob = bucket.blob(source_blob_name)
|
||||
#
|
||||
# blob.download_to_filename(destination_file_name)
|
||||
#
|
||||
# print('Blob {} downloaded to {}.'.format(
|
||||
# source_blob_name,
|
||||
# destination_file_name))
|
||||
|
|
@ -0,0 +1,303 @@
|
|||
from kafka import KafkaProducer, KafkaConsumer,TopicPartition
|
||||
from kafka.errors import kafka_errors
|
||||
import os,cv2,sys,json,time
|
||||
import numpy as np
|
||||
import requests
|
||||
def query_channel_status(channelIndex):
    """Query the live-channel status API; return (info message, taskEnd flag).

    taskEnd is False only while the channel status is 2 ("in use"); any
    other status or a request failure ends the task.
    """
    channel_query_api='https://streaming.t-aaron.com/livechannel/getLiveStatus/%s'%(channelIndex)
    #https://streaming.t-aaron.com/livechannel/getLiveStatus/LC001
    try:
        res = requests.get(channel_query_api,timeout=10).json()
        if res['data']['status']==2:# 1: idle, 2: in use, 3: disabled, 4: pending close
            taskEnd=False
        else:
            taskEnd=True
        infos='channel_query_api connected'
    except Exception as e:
        taskEnd=True
        infos='channel_query_api not connected:%s'%(e)
    return infos, taskEnd
|
||||
|
||||
def query_request_status(request_url):
    """Query an analysis-request status URL; return (info message, taskEnd flag).

    taskEnd is False only while status is 5 ("running"); any other status
    or a request failure ends the task.
    """
    #channel_query_api='https://streaming.t-aaron.com/livechannel/getLiveStatus/%s'%(channelIndex)
    channel_request_api=request_url

    try:
        res = requests.get(channel_request_api,timeout=10).json()
        if res['data']['status']==5:# 5: running, 10: pending stop, 15: finished
            taskEnd=False
        else:
            taskEnd=True
        infos='channel_request_api connected'
    except Exception as e:
        taskEnd=True
        infos='channel_request_api not connected:%s'%(e)
    return infos, taskEnd
|
||||
|
||||
def get_needed_objectsIndex(object_config):
    """Collect the integer ids from *object_config* entries.

    Args:
        object_config: iterable of dicts, each expected to carry an 'id'
            key whose value converts to int.

    Returns:
        tuple: (list of int ids, comma-joined string of those ids).
        Entries with a missing or non-numeric 'id' are skipped.
    """
    needed_objectsIndex = []
    for model in object_config:
        try:
            needed_objectsIndex.append(int(model['id']))
        except (KeyError, TypeError, ValueError):
            # FIX: the original swallowed everything with `a=1`; skip
            # malformed entries explicitly and narrowly instead.
            continue
    allowedList_string = ','.join(str(x) for x in needed_objectsIndex)
    return needed_objectsIndex, allowedList_string
|
||||
|
||||
|
||||
def get_infos(taskId, msgId, msg_h, key_str='waiting stream or video, send heartbeat'):
    """Build the success/failure/re-failure log strings for a kafka send."""
    return {
        'success': '%s, taskId:%s msgId:%s send:%s' % (key_str, taskId, msgId, msg_h),
        'failure': 'kafka ERROR, %s' % (key_str),
        'Refailure': 'kafka Re-send ERROR ,%s' % (key_str),
    }
|
||||
def writeTxtEndFlag(outImaDir,streamName,imageTxtFile,endFlag='结束'):
    """Write sentinel "end" images (AI + OR variants) signalling stream completion.

    Each sentinel is a 100x100 black jpg whose filename encodes the end flag,
    timestamp and stream name; when imageTxtFile is truthy, a companion .txt
    containing the image path is written next to each jpg.
    """
    #time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    EndUrl='%s/%s_frame-9999-9999_type-%s_9999999999999999_s-%s_AI.jpg'%(outImaDir,time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),endFlag,streamName)
    EndUrl = EndUrl.replace(' ','-').replace(':','-')  # make the timestamp filename-safe
    img_end=np.zeros((100,100),dtype=np.uint8);cv2.imwrite(EndUrl,img_end)
    if imageTxtFile:
        EndUrl_txt = EndUrl.replace('.jpg','.txt')
        fp_t=open(EndUrl_txt,'w');fp_t.write(EndUrl+'\n');fp_t.close()

    # second sentinel: the "_OR" (original) variant
    EndUrl='%s/%s_frame-9999-9999_type-%s_9999999999999999_s-%s_OR.jpg'%(outImaDir,time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),endFlag,streamName)
    EndUrl = EndUrl.replace(' ','-').replace(':','-')
    ret = cv2.imwrite(EndUrl,img_end)
    if imageTxtFile:
        EndUrl_txt = EndUrl.replace('.jpg','.txt')
        fp_t=open(EndUrl_txt,'w');fp_t.write(EndUrl+'\n');fp_t.close()
|
||||
def get_current_time():
    """Return the current local time as a 'YYYY-mm-dd HH:MM:SS.mmm' string."""
    now = time.time()
    whole = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now))
    millis = (now - int(now)) * 1000
    return "%s.%03d" % (whole, millis)
|
||||
|
||||
|
||||
|
||||
def send_kafka(producer,par,msg,outStrList,fp_log,logger,line='000',thread='detector',printFlag=False ):
    """Send *msg* to kafka topic par['topic'] and log the outcome via writeELK_log.

    On the first failure the producer is closed, rebuilt from par['server'],
    and the send retried once; a second failure is logged at ERROR level.
    outStrList supplies the 'success' / 'failure' / 'Refailure' message
    templates (see get_infos).
    """
    future = producer.send(par['topic'], msg)
    try:
        record_metadata = future.get()  # block until the broker acknowledges
        outstr=outStrList['success']
        #outstr=wrtiteLog(fp_log,outstr);print( outstr);
        writeELK_log(outstr,fp_log,level='INFO',thread=thread,line=line,logger=logger,printFlag=printFlag)
    except Exception as e:
        outstr='%s , warning: %s'%( outStrList['failure'],str(e))
        writeELK_log(outstr,fp_log,level='WARNING',thread=thread,line=line,logger=logger,printFlag=printFlag)
        try:
            # Rebuild the producer and retry the send once.
            producer.close()
            # NOTE(review): the trailing .get() here is applied to the KafkaProducer
            # object itself, not to a send future — confirm this is intended.
            producer = KafkaProducer(bootstrap_servers=par['server'], value_serializer=lambda v: v.encode('utf-8')).get()
            future = producer.send(par['topic'], msg).get()
        except Exception as e:
            outstr='%s, error: %s'%( outStrList['Refailure'],str(e))
            #outstr=wrtiteLog(fp_log,outstr);print( outstr);
            writeELK_log(outstr,fp_log,level='ERROR',thread=thread,line=line,logger=logger,printFlag=printFlag)
|
||||
|
||||
def check_time_interval(time0_beg, time_interval):
    """If more than *time_interval* seconds passed since *time0_beg*,
    return (now, True); otherwise return (time0_beg, False)."""
    now = time.time()
    if now - time0_beg > time_interval:
        return now, True
    return time0_beg, False
|
||||
def addTime(strs):
    """Prefix *strs* with a newline and the current local timestamp.

    Returns the composed string. Bug fix: the original built `outstr`
    but fell through and returned None, making the call useless.
    """
    timestr = time.strftime("%Y-%m-%d %H:%M:%S ", time.localtime())
    outstr = '\n %s %s ' % (timestr, strs)
    return outstr
|
||||
|
||||
|
||||
def get_file():
    """Debug helper: print the current file name, function name and caller name."""
    frame = sys._getframe()
    print("文件名 :", __file__, frame.f_lineno)
    print("函数名: ", frame.f_code.co_name)
    print("模块名: ", frame.f_back.f_code.co_name)
|
||||
|
||||
def writeELK_log(msg, fp, level='INFO', thread='detector', logger='kafka_yolov5', line=9999, newLine=False, printFlag=True):
    """Format *msg* as an ELK-style log record, append it to file object *fp*
    (flushed), optionally echo it to stdout, and return the record string."""
    #timestr=time.strftime("%Y-%m-%d %H:%M:%S ", time.localtime())
    stamp = get_current_time()
    record = '%s [%s][%s][%d][%s]- %s' % (stamp, level, thread, line, logger, msg)
    if newLine:
        record = '\n' + record
    fp.write(record + '\n')
    fp.flush()
    if printFlag:
        print(record)
    return record
|
||||
|
||||
|
||||
def wrtiteLog(fp, strs, newLine=False):
    """Write *strs* to *fp* prefixed with a local timestamp (flushed);
    return the written line without its trailing newline."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S ", time.localtime())
    prefix = '\n ' if newLine else ''
    line = '%s%s %s ' % (prefix, stamp, strs)
    fp.write(line + '\n')
    fp.flush()
    return line
|
||||
|
||||
def create_logFile(logdir='logdir', name=None):
    """Open a log file inside *logdir* and return the file object.

    When *name* is omitted the file is named after today's date
    (YYYY-MM-DD.txt). Existing files are opened for append ('a+'),
    new ones are created ('w').
    """
    fname = name if name else time.strftime("%Y-%m-%d.txt", time.localtime())
    logname = logdir + '/' + fname
    mode = 'a+' if os.path.exists(logname) else 'w'
    return open(logname, mode)
|
||||
def get_boradcast_address(outResource):
    """Map a push resource to its matching play (broadcast) address.

    Resources containing '1945' are the Aliyun (t-aaron) endpoint;
    anything else falls back to the Tencent endpoint.
    """
    # rtmp://live.push.t-aaron.com/live/THSB  -> Aliyun, 1945
    # rtmp://demopush.yunhengzhizao.cn/live/THSB -> Tencent, 1935
    if '1945' in outResource:
        return 'rtmp://live.play.t-aaron.com/live/THSB'
    return 'rtmp://demoplay.yunhengzhizao.cn/live/THSB_HD5M'
|
||||
def save_message(kafka_dir, msg):
    """Persist a kafka message dict as <kafka_dir>/<request_id>.json.

    *kafka_dir* must already exist (AssertionError otherwise).
    """
    assert os.path.exists(kafka_dir)
    outtxt = os.path.join(kafka_dir, msg['request_id'] + '.json')
    with open(outtxt, 'w') as fp:
        json.dump(msg, fp, ensure_ascii=False)
|
||||
|
||||
|
||||
|
||||
def get_push_address(outResource):
    """Map an output stream resource to the local relay push url keyed by stream name.

    Known endpoints (translated from the original notes):
      rtmp://live.push.t-aaron.com/live/THSB      -> Aliyun, port 1945
      rtmp://demopush.yunhengzhizao.cn/live/THSB  -> Tencent, port 1935
      terminal push: rtmp://live.push.t-aaron.com/live/THSAa
      terminal play: rtmp://live.play.t-aaron.com/live/THSAa_hd
      AI push:       rtmp://live.push.t-aaron.com/live/THSBa
      AI play:       rtmp://live.play.t-aaron.com/live/THSBa_hd
    """
    # Each THSB* stream name maps to its own local relay port.
    if 't-aaron' in outResource:
        if 'THSBa' in outResource: port=1975
        elif 'THSBb' in outResource: port=1991
        elif 'THSBc' in outResource: port=1992
        elif 'THSBd' in outResource: port=1993
        elif 'THSBe' in outResource: port=1994
        elif 'THSBf' in outResource: port=1995
        elif 'THSBg' in outResource: port=1996
        elif 'THSBh' in outResource: port=1997
        else: port=1945
    else: port=1935
    return 'rtmp://127.0.0.1:%d/live/test'%(port)
    return outResource  # NOTE(review): unreachable — the statement above always returns; confirm which return is intended
|
||||
def getAllRecord_poll(consumer):
    """Poll the kafka consumer once (5 s timeout) and return all fetched
    records as one flat list.

    Bug fix: the original flattened the poll result into `out` and then
    extended `out` with the same per-partition lists again, so every
    record appeared twice in the returned list.
    """
    msgs = consumer.poll(5000)  # {TopicPartition: [records]}
    out = []
    for records in msgs.values():
        out.extend(records)
    return out
|
||||
def getAllRecords(consumer,topics):
    """Drain the estimated backlog of the first two topics and return the records.

    The backlog size is estimated with get_left_cnt; each consumed record is
    committed immediately. Returns [] when nothing is pending.
    """
    leftCnt = 0
    for topic in topics[0:2]:  # only the first two topics contribute to the estimate
        leftCnt+=get_left_cnt(consumer,topic)
    out = []
    if leftCnt == 0:
        return []
    for ii,msg in enumerate(consumer):
        consumer.commit()  # commit每条记录 -> commit as we go so a crash does not replay
        out.append(msg)
        if ii== (leftCnt-1):
            break  # stream interrupted or reached the estimated end
    return out
|
||||
|
||||
def get_left_cnt(consumer,topic):
    """Return how many messages remain un-committed in *topic*.

    Computed as (sum of end offsets) - (sum of committed offsets) over all
    partitions. Requires kafka-python's TopicPartition in scope; partitions
    that were never committed (None) contribute 0 to the committed sum.
    """
    partitions = [TopicPartition(topic, p) for p in consumer.partitions_for_topic(topic)]

    # total: latest (end) offset of every partition, sorted by partition id
    toff = consumer.end_offsets(partitions)
    toff = [(key.partition, toff[key]) for key in toff.keys()]
    toff.sort()

    # current: committed offset of every partition (None when never committed)
    coff = [(x.partition, consumer.committed(x)) for x in partitions]
    coff.sort()

    # cal sum and left
    toff_sum = sum([x[1] for x in toff])
    cur_sum = sum([x[1] for x in coff if x[1] is not None])
    left_sum = toff_sum - cur_sum

    return left_sum
|
||||
def view_bar(num, total, time1, prefix='prefix'):
    """Render an in-place console progress bar.

    num/total drive the 30-char bar fill; *time1* is the start timestamp
    used to display elapsed seconds. Fix: dropped the original's unused
    local `rate_nums` (a percentage that was computed and never used).
    """
    rate = num / total
    filled = int(rate * 30)
    elapsed = time.time() - time1
    bar = '\r %s %d / %d [%s%s] %.2f s' % (
        prefix, num, total, ">" * filled, " " * (30 - filled), elapsed)
    sys.stdout.write(bar)
    sys.stdout.flush()
|
||||
def get_total_cnt(inSource):
    """Return (frame_count, fps) of a video/stream probed via OpenCV.

    Raises AssertionError when the source cannot be opened.
    """
    cap=cv2.VideoCapture(inSource)
    assert cap.isOpened()
    cnt=cap.get(7)  # 7 == cv2.CAP_PROP_FRAME_COUNT
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return cnt,fps
|
||||
def check_stream(inSource,producer,par,msg,outStrList ,fp_log,logger,line='000',thread='detector',timeMs=120,):
    """Wait up to roughly *timeMs* seconds for the input stream to become readable.

    Re-probes every 10 s; every third failed probe sends a heartbeat *msg* to
    kafka via send_kafka. Returns True as soon as the stream both opens and
    reports sane fps/size (see get_fps_rtmp), False after the timeout.
    """
    cnt =(timeMs-1)//10 + 1  # number of 10-second probe rounds
    Stream_ok=False
    for icap in range(cnt):
        cap=cv2.VideoCapture(inSource)
        if cap.isOpened() and get_fps_rtmp(inSource,video=False)[0] :
            Stream_ok=True ;cap.release();break;
        #Stream_ok,_= get_fps_rtmp(inSource,video=False)
        #if Stream_ok:cap.release();break;
        else:
            # NOTE(review): `cap` is not released on this failure path — a new
            # capture is opened each round; consider cap.release() here.
            Stream_ok=False
            timestr=time.strftime("%Y-%m-%d %H:%M:%S ", time.localtime())
            outstr='Waiting stream %d s'%(10*icap)
            writeELK_log(msg=outstr,fp=fp_log,thread=thread,line=line,logger=logger)
            time.sleep(10)
            if icap%3==0:  # heartbeat roughly every 30 s
                send_kafka(producer,par,msg,outStrList,fp_log,logger=logger,line=line,thread=thread )
    return Stream_ok
|
||||
|
||||
|
||||
|
||||
|
||||
def get_fps_rtmp(inSource, video=False):
    """Probe *inSource* with OpenCV and return (ok, [fps, width, height, cnt]).

    ok is False when the source cannot be opened, has a zero-sized frame, or
    reports an implausible fps (> 30); the metrics list is zeros in that case.
    cnt (total frame count) is only read when *video* is True.

    Bug fix: the original returned on the zero-size / bad-fps path before
    calling cap.release(), leaking the opened capture handle. The capture is
    now released on every path once it is open.
    """
    cap = cv2.VideoCapture(inSource)
    if not cap.isOpened():
        print('#####error url:', inSource)
        return False, [0, 0, 0, 0]

    fps = cap.get(cv2.CAP_PROP_FPS)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    cnt = 0
    if video:
        cnt = cap.get(7)  # 7 == cv2.CAP_PROP_FRAME_COUNT
    cap.release()  # released before any return below

    if width * height == 0 or fps > 30:
        return False, [0, 0, 0, 0]
    try:
        outx = [int(x + 0.5) for x in [fps, width, height, cnt]]  # round to int
        return True, outx
    except Exception:
        return False, [0, 0, 0, 0]
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,223 @@
|
|||
# Model validation metrics
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import general
|
||||
|
||||
|
||||
def fitness(x):
    """Model fitness: weighted sum over the [P, R, mAP@0.5, mAP@0.5:0.95]
    columns, with weights [0.0, 0.0, 0.1, 0.9]."""
    weights = [0.0, 0.0, 0.1, 0.9]
    return (x[:, :4] * weights).sum(1)
|
||||
|
||||
|
||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Plot precision-recall curve at mAP@0.5
        save_dir:  Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness (descending confidence)
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c  # mask of predictions for this class
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue  # nothing to score for this class
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + 1e-16)  # recall curve
            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

            # AP from recall-precision curve (one AP per IoU threshold column)
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
                if plot and j == 0:
                    py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)
    if plot:
        plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
        plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
        plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
        plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

    i = f1.mean(0).argmax()  # max F1 index: report P/R/F1 at the best-F1 confidence
    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
|
||||
|
||||
|
||||
def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, precision curve, recall curve
    """
    # Sentinel values at both ends of the curves
    mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
    mpre = np.concatenate(([1.], precision, [0.]))

    # Precision envelope: make mpre monotonically non-increasing
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate the area under the envelope
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        xs = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(xs, mrec, mpre), xs)
    else:  # 'continuous'
        changed = np.where(mrec[1:] != mrec[:-1])[0]  # points where recall changes
        ap = np.sum((mrec[changed + 1] - mrec[changed]) * mpre[changed + 1])

    return ap, mpre, mrec
|
||||
|
||||
|
||||
class ConfusionMatrix:
    """Running (nc+1)x(nc+1) detection confusion matrix; the extra row/column
    accumulates background false positives / false negatives."""
    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        # NOTE(review): this instance attribute shadows the matrix() method
        # defined below — cm.matrix is the array, cm.matrix() raises TypeError.
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc  # number of classes
        self.conf = conf  # detection confidence threshold
        self.iou_thres = iou_thres  # IoU threshold for a label/detection match

    def process_batch(self, detections, labels):
        """
        Return intersection-over-union (Jaccard index) of boxes.
        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        detections = detections[detections[:, 4] > self.conf]  # drop low-confidence detections
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = general.box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # matches rows: (label_idx, detection_idx, iou); the sort/unique
            # passes keep only the single best match per detection and per label
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(np.int16)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[gc, detection_classes[m1[j]]] += 1  # correct
            else:
                self.matrix[self.nc, gc] += 1  # background FP

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[dc, self.nc] += 1  # background FN

    def matrix(self):
        # NOTE(review): effectively unreachable — shadowed by the instance
        # attribute of the same name assigned in __init__.
        return self.matrix

    def plot(self, save_dir='', names=()):
        """Save a seaborn heatmap of the normalized matrix to confusion_matrix.png."""
        try:
            import seaborn as sn

            array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6)  # normalize
            array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

            fig = plt.figure(figsize=(12, 9), tight_layout=True)
            sn.set(font_scale=1.0 if self.nc < 50 else 0.8)  # for label size
            labels = (0 < len(names) < 99) and len(names) == self.nc  # apply names to ticklabels
            sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
                       xticklabels=names + ['background FP'] if labels else "auto",
                       yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
            fig.axes[0].set_xlabel('True')
            fig.axes[0].set_ylabel('Predicted')
            fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
        except Exception as e:
            # NOTE(review): broad silent swallow — a missing seaborn or bad
            # save_dir is ignored without any log output.
            pass

    def print(self):
        """Dump the raw matrix rows to stdout, space-separated."""
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
|
||||
|
||||
|
||||
# Plots ----------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
    """Save a precision-recall plot to *save_dir*; a per-class legend is shown
    only when 0 < len(names) < 21."""
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    stacked = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for idx, curve in enumerate(stacked.T):
            ax.plot(px, curve, linewidth=1, label=f'{names[idx]} {ap[idx, 0]:.3f}')
    else:
        ax.plot(px, stacked, linewidth=1, color='grey')

    ax.plot(px, stacked.mean(1), linewidth=3, color='blue',
            label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir), dpi=250)
|
||||
|
||||
|
||||
def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
    """Save a metric-vs-confidence plot to *save_dir*; a per-class legend is
    shown only when 0 < len(names) < 21."""
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for idx, curve in enumerate(py):
            ax.plot(px, curve, linewidth=1, label=f'{names[idx]}')
    else:
        ax.plot(px, py.T, linewidth=1, color='grey')

    mean_curve = py.mean(0)
    ax.plot(px, mean_curve, linewidth=3, color='blue',
            label=f'all classes {mean_curve.max():.2f} at {px[mean_curve.argmax()]:.3f}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir), dpi=250)
|
||||
|
|
@ -0,0 +1,613 @@
|
|||
# Plotting utils
|
||||
|
||||
import glob
|
||||
import math
|
||||
import os,sys
|
||||
import random
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import torch
|
||||
import yaml
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from scipy.signal import butter, filtfilt,savgol_filter
|
||||
|
||||
from util.general import xywh2xyxy, xyxy2xywh
|
||||
from util.metrics import fitness
|
||||
|
||||
# Settings
|
||||
matplotlib.rc('font', **{'size': 11})
|
||||
#matplotlib.use('Agg') # for writing to files only
|
||||
|
||||
def smooth_outline(contours, p1, p2):
    """Smooth the x and y coordinates of the first contour with a
    Savitzky-Golay filter (window length p1, polynomial order p2).

    contours: OpenCV-style contour list of shape (1, N, 1, 2).
    Returns a new array with smoothed coordinates.
    """
    arc = np.array(contours)
    xs = arc[0, :, 0, 0]
    ys = arc[0, :, 0, 1]
    arc[0, :, 0, 0] = savgol_filter(xs, p1, p2)
    arc[0, :, 0, 1] = savgol_filter(ys, p1, p2)
    return arc
|
||||
def smooth_outline_auto(contours):
    """Pick Savitzky-Golay parameters from the contour length and smooth it.

    Window is the odd number near len/6; polynomial order is 3, reduced to
    window-1 when the window itself is smaller than 3.
    """
    n_pts = len(contours[0])
    window = int(n_pts / 12) * 2 + 1  # odd window length
    order = window - 1 if window < 3 else 3
    return smooth_outline(contours, window, order)
|
||||
|
||||
def get_websource(txtfile):
    """Parse a stream-list file into (urls, ports, stream_names).

    Each line has the form "<url> <port>". For rtmp urls the stream name is
    the path segment before the first '_' (e.g.
    rtmp://liveplay.yunhengzhizao.cn/live/demo_HD5M -> 'demo'); otherwise
    the last 3 characters of the url. Malformed lines are reported and
    skipped. Raises AssertionError when no valid line was found.

    Fixes vs. original: the bare `except:` is narrowed to Exception, and all
    three lists are appended atomically — the original appended the url
    before parsing name/port, so a half-parsable line left the lists out of
    sync with each other.
    """
    webs, ports, streamNames = [], [], []
    with open(txtfile, 'r') as fp:
        lines = fp.readlines()
    for line in lines:
        try:
            sps = line.strip().split(' ')
            url = sps[0]
            if 'rtmp' in url:
                name = url.split('/')[4].split('_')[0]
            else:
                name = url[-3:]
            port = sps[1]
        except Exception:  # malformed line: report and skip
            print('####format error : %s , in file:%s#####' % (line, txtfile))
            continue
        webs.append(url)
        ports.append(port)
        streamNames.append(name)
    assert len(webs) > 0
    return webs, ports, streamNames
|
||||
|
||||
|
||||
def color_list():
    """First 10 matplotlib TABLEAU colors as (r, g, b) integer tuples.

    https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    """
    def hex_to_rgb(code):
        # '#rrggbb' -> three two-char hex slices starting after the '#'
        return tuple(int(code[i:i + 2], 16) for i in (1, 3, 5))

    return [hex_to_rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()]
|
||||
|
||||
|
||||
def hist2d(x, y, n=100):
    """Per-sample log-density from an n-bin 2D histogram of (x, y)
    (used in labels.png and evolve.png)."""
    xedges = np.linspace(x.min(), x.max(), n)
    yedges = np.linspace(y.min(), y.max(), n)
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
    xi = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
    yi = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
    return np.log(hist[xi, yi])
|
||||
|
||||
|
||||
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    """Zero-phase Butterworth low-pass filter (forward-backward filtfilt).

    https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    """
    nyquist = 0.5 * fs
    b, a = butter(order, cutoff / nyquist, btype='low', analog=False)
    return filtfilt(b, a, data)
|
||||
|
||||
|
||||
'''image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image)
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
font = ImageFont.truetype('./font/platech.ttf', 40, encoding='utf-8')
|
||||
for info in infos:
|
||||
detect = info['bndbox']
|
||||
text = ','.join(list(info['attributes'].values()))
|
||||
temp = -50
|
||||
if info['name'] == 'vehicle':
|
||||
temp = 20
|
||||
draw.text((detect[0], detect[1] + temp), text, (0, 255, 255), font=font)
|
||||
if 'scores' in info:
|
||||
draw.text((detect[0], detect[3]), info['scores'], (0, 255, 0), font=font)
|
||||
if 'pscore' in info:
|
||||
draw.text((detect[2], detect[3]), str(round(info['pscore'],3)), (0, 255, 0), font=font)
|
||||
image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
for info in infos:
|
||||
detect = info['bndbox']
|
||||
cv2.rectangle(image, (detect[0], detect[1]), (detect[2], detect[3]), (0, 255, 0), 1, cv2.LINE_AA)
|
||||
return image'''
|
||||
|
||||
'''def plot_one_box_PIL(x, im, color=None, label=None, line_thickness=3):
|
||||
# Plots one bounding box on image 'im' using OpenCV
|
||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
|
||||
tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
|
||||
color = color or [random.randint(0, 255) for _ in range(3)]
|
||||
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
|
||||
|
||||
|
||||
|
||||
cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
|
||||
|
||||
|
||||
if label:
|
||||
tf = max(tl - 1, 1) # font thickness
|
||||
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
|
||||
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
|
||||
cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
|
||||
|
||||
im = Image.fromarray(im)
|
||||
draw = ImageDraw.Draw(im)
|
||||
font = ImageFont.truetype('./font/platech.ttf', t_size, encoding='utf-8')
|
||||
draw.text((c1[0], c1[1] - 2), label, (0, 255, 0), font=font)
|
||||
|
||||
#cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
return np.array(im) '''
|
||||
|
||||
def plot_one_box(x, im, color=None, label=None, line_thickness=3):
    """Draw one xyxy box (and optional filled label banner) on image *im*
    in place using OpenCV."""
    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
    tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(im, top_left, bottom_right, color, thickness=tl, lineType=cv2.LINE_AA)

    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        banner_br = top_left[0] + t_size[0], top_left[1] - t_size[1] - 3
        cv2.rectangle(im, top_left, banner_br, color, -1, cv2.LINE_AA)  # filled banner
        cv2.putText(im, label, (top_left[0], top_left[1] - 2), 0, tl / 3,
                    [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
|
||||
|
||||
def plot_one_box_PIL(box, im, color=None, label=None, line_thickness=None):
    """Draw one box (and optional label banner) on image array *im* using PIL.

    Returns a new numpy array; *im* itself is not modified. Fix: removed the
    original's dead local `im_array` (a duplicate np.asarray conversion whose
    result was never used).

    NOTE(review): `color=None` will fail at tuple(color) — callers appear to
    always pass a color; confirm.
    """
    im = Image.fromarray(im)
    draw = ImageDraw.Draw(im)
    line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
    draw.rectangle(box, width=line_thickness, outline=tuple(color))  # box outline

    if label:
        fontsize = max(round(max(im.size) / 40), 12)
        font = ImageFont.truetype("../AIlib2/conf/platech.ttf", fontsize, encoding='utf-8')
        txt_width, txt_height = font.getsize(label)
        # filled banner above the box's top-left corner, then the label text
        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
        draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)

    return np.asarray(im)
|
||||
|
||||
# def draw_painting_joint(box,img,label_array,score=0.5,color=None,font={ 'line_thickness':None,'boxLine_thickness':None, 'fontSize':None},socre_location="leftTop"):
|
||||
# #如果box[0]不是list or 元组,则box是[ (x0,y0),(x1,y1),(x2,y2),(x3,y3)]四点格式
|
||||
# if isinstance(box[0], (list, tuple,np.ndarray ) ):
|
||||
# ###先把中文类别字体赋值到img中
|
||||
# lh, lw, lc = label_array.shape
|
||||
# imh, imw, imc = img.shape
|
||||
# if socre_location=='leftTop':
|
||||
# x0 , y1 = box[0][0],box[0][1]
|
||||
# elif socre_location=='leftBottom':
|
||||
# x0,y1=box[3][0],box[3][1]
|
||||
# else:
|
||||
# print('plot.py line217 ,label_location:%s not implemented '%( socre_location ))
|
||||
# sys.exit(0)
|
||||
|
||||
# x1 , y0 = x0 + lw , y1 - lh
|
||||
# if y0<0:y0=0;y1=y0+lh
|
||||
# if y1>imh: y1=imh;y0=y1-lh
|
||||
# if x0<0:x0=0;x1=x0+lw
|
||||
# if x1>imw:x1=imw;x0=x1-lw
|
||||
# img[y0:y1,x0:x1,:] = label_array
|
||||
# pts_cls=[(x0,y0),(x1,y1) ]
|
||||
|
||||
# #把四边形的框画上
|
||||
# box_tl= font['boxLine_thickness'] or round(0.002 * (imh + imw) / 2) + 1
|
||||
# cv2.polylines(img, [box], True,color , box_tl)
|
||||
|
||||
# ####把英文字符score画到类别旁边
|
||||
# tl = font['line_thickness'] or round(0.002*(imh+imw)/2)+1#line/font thickness
|
||||
# label = ' %.2f'%(score)
|
||||
# tf = max(tl , 1) # font thickness
|
||||
# fontScale = font['fontSize'] or tl * 0.33
|
||||
# t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]
|
||||
|
||||
|
||||
# #if socre_location=='leftTop':
|
||||
# p1,p2= (pts_cls[1][0], pts_cls[0][1]),(pts_cls[1][0]+t_size[0],pts_cls[1][1])
|
||||
# cv2.rectangle(img, p1 , p2, color, -1, cv2.LINE_AA)
|
||||
# p3 = pts_cls[1][0],pts_cls[1][1]-(lh-t_size[1])//2
|
||||
|
||||
# cv2.putText(img, label,p3, 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
# return img
|
||||
# else:####两点格式[x0,y0,x1,y1]
|
||||
# try:
|
||||
# box = [int(xx.cpu()) for xx in box]
|
||||
# except:
|
||||
# box=[ int(x) for x in box]
|
||||
# ###先把中文类别字体赋值到img中
|
||||
# lh, lw, lc = label_array.shape
|
||||
# imh, imw, imc = img.shape
|
||||
# if socre_location=='leftTop':
|
||||
# x0 , y1 = box[0:2]
|
||||
# elif socre_location=='leftBottom':
|
||||
# x0,y1=box[0],box[3]
|
||||
# else:
|
||||
# print('plot.py line217 ,socre_location:%s not implemented '%( socre_location ))
|
||||
# sys.exit(0)
|
||||
# x1 , y0 = x0 + lw , y1 - lh
|
||||
# if y0<0:y0=0;y1=y0+lh
|
||||
# if y1>imh: y1=imh;y0=y1-lh
|
||||
# if x0<0:x0=0;x1=x0+lw
|
||||
# if x1>imw:x1=imw;x0=x1-lw
|
||||
# img[y0:y1,x0:x1,:] = label_array
|
||||
|
||||
|
||||
|
||||
# ###把矩形框画上,指定颜色和线宽
|
||||
# tl = font['line_thickness'] or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
|
||||
# box_tl= font['boxLine_thickness'] or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
|
||||
# c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
||||
# cv2.rectangle(img, c1, c2, color, thickness=box_tl, lineType=cv2.LINE_AA)
|
||||
|
||||
# ###把英文字符score画到类别旁边
|
||||
# label = ' %.2f'%(score)
|
||||
# tf = max(tl , 1) # font thickness
|
||||
# fontScale = font['fontSize'] or tl * 0.33
|
||||
# t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]
|
||||
|
||||
# if socre_location=='leftTop':
|
||||
# c2 = c1[0]+ lw + t_size[0], c1[1] - lh
|
||||
# cv2.rectangle(img, (int(box[0])+lw,int(box[1])) , c2, color, -1, cv2.LINE_AA) # filled
|
||||
# cv2.putText(img, label, (c1[0]+lw, c1[1] - (lh-t_size[1])//2 ), 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
# elif socre_location=='leftBottom':
|
||||
# c2 = box[0]+ lw + t_size[0], box[3] - lh
|
||||
# cv2.rectangle(img, (int(box[0])+lw,int(box[3])) , c2, color, -1, cv2.LINE_AA) # filled
|
||||
# cv2.putText(img, label, ( box[0] + lw, box[3] - (lh-t_size[1])//2 ), 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
|
||||
# #print('#####line224 fontScale:',fontScale,' thickness:',tf,' line_thickness:',font['line_thickness'],' boxLine thickness:',box_tl)
|
||||
# return img
|
||||
|
||||
def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
    """Compare exp(x) (YOLOv3) against (2*sigmoid(x))^p (YOLOv5) anchor
    width-height formulas; saves comparison.png.
    https://github.com/ultralytics/yolov3/issues/168
    """
    xs = np.arange(-4.0, 4.0, .1)
    y_exp = np.exp(xs)
    y_sig = torch.sigmoid(torch.from_numpy(xs)).numpy() * 2

    fig = plt.figure(figsize=(6, 3), tight_layout=True)
    plt.plot(xs, y_exp, '.-', label='YOLOv3')
    plt.plot(xs, y_sig ** 2, '.-', label='YOLOv5 ^2')
    plt.plot(xs, y_sig ** 1.6, '.-', label='YOLOv5 ^1.6')
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.grid()
    plt.legend()
    fig.savefig('comparison.png', dpi=200)
|
||||
|
||||
|
||||
def output_to_target(output):
    """Flatten per-image model predictions into rows of
    [batch_id, class_id, x, y, w, h, conf]."""
    targets = []
    for batch_idx, det in enumerate(output):
        for *box, conf, cls in det.cpu().numpy():
            xywh = xyxy2xywh(np.array(box)[None])  # (1, 4) xyxy -> xywh
            targets.append([batch_idx, cls, *list(*xywh), conf])
    return np.array(targets)
|
||||
|
||||
|
||||
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    """Assemble up to *max_subplots* images, with their targets drawn, into one mosaic.

    images: (bs, C, H, W) tensor/array, values in [0, 1] or [0, 255].
    targets: rows [batch_id, class_id, x, y, w, h, (conf)] — the conf column
             is present for predictions and absent for ground-truth labels.
    Returns the mosaic as a numpy array; also saves it to *fname* when set.
    """
    # Plot image grid with labels

    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise
    if np.max(images[0]) <= 1:
        images *= 255

    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Check if we should resize
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

    colors = color_list()  # list of colors
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break

        # top-left corner of this image's tile in the mosaic
        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        img = img.transpose(1, 2, 0)  # CHW -> HWC
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            labels = image_targets.shape[1] == 6  # labels if no conf column
            conf = None if labels else image_targets[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale_factor < 1:  # absolute coords need scale if image scales
                    boxes *= scale_factor
            # shift boxes into this tile's position in the mosaic
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = colors[cls % len(colors)]
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)

        # Draw image filename labels
        if paths:
            label = Path(paths[i]).name[:40]  # trim to 40 char
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
                        lineType=cv2.LINE_AA)

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname:
        r = min(1280. / max(h, w) / ns, 1.0)  # ratio to limit image size
        mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
        # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))  # cv2 save
        Image.fromarray(mosaic).save(fname)  # PIL save
    return mosaic
|
||||
|
||||
|
||||
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    """Simulate `epochs` scheduler steps and save the LR curve to <save_dir>/LR.png."""
    # Work on copies so the caller's optimizer/scheduler state is untouched.
    optimizer = copy(optimizer)
    scheduler = copy(scheduler)
    lrs = []
    for _ in range(epochs):
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    plt.plot(lrs, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
    plt.close()
|
||||
|
||||
|
||||
def plot_test_txt():  # from utils.plots import *; plot_test()
    """Plot 2-D and 1-D histograms of box centers read from test.txt."""
    data = np.loadtxt('test.txt', dtype=np.float32)
    centers = xyxy2xywh(data[:, :4])
    center_x = centers[:, 0]
    center_y = centers[:, 1]

    fig, axis = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    axis.hist2d(center_x, center_y, bins=600, cmax=10, cmin=0)
    axis.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    axes[0].hist(center_x, bins=600)
    axes[1].hist(center_y, bins=600)
    plt.savefig('hist1d.png', dpi=200)
|
||||
|
||||
|
||||
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
    """Plot histograms of the x / y / width / height columns of targets.txt."""
    data = np.loadtxt('targets.txt', dtype=np.float32).T
    titles = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, axes = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    axes = axes.ravel()
    for idx, title in enumerate(titles):
        component = data[idx]
        axes[idx].hist(component, bins=100, label='%.3g +/- %.3g' % (component.mean(), component.std()))
        axes[idx].legend()
        axes[idx].set_title(title)
    plt.savefig('targets.jpg', dpi=200)
|
||||
|
||||
|
||||
def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_txt()
    """Plot speed/accuracy study files ('study*.txt' produced by test.py) against EfficientDet.

    Args:
        path: directory searched for study*.txt files; also names the output PNG.
        x: optional x values; when None, an index range is generated per file.
           NOTE: `x` is reassigned inside the loop, so after the first file it
           keeps that value for all remaining files.
    """
    # Plot study.txt generated by test.py
    fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
    # ax = ax.ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
    for f in sorted(Path(path).glob('study*.txt')):
        # Columns: P, R, mAP@.5, mAP@.5:.95, t_inference, t_NMS, t_total (per `s` below).
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
        # for i in range(7):
        #     ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
        #     ax[i].set_title(s[i])

        # Plot mAP@.5:.95 vs total time up to (and including) the best-mAP point.
        j = y[3].argmax() + 1
        ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    # Reference EfficientDet speed/accuracy curve (hard-coded published numbers).
    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

    ax2.grid(alpha=0.2)
    ax2.set_yticks(np.arange(20, 60, 5))
    ax2.set_xlim(0, 57)
    ax2.set_ylim(30, 55)
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    plt.savefig(str(Path(path).name) + '.png', dpi=300)
|
||||
|
||||
|
||||
def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
    """Plot dataset label statistics (class histogram, correlogram, sample boxes).

    Args:
        labels: (n, 5) array of [class, x, y, w, h] rows; mutated in place when
                drawing the sample rectangles.
        names: optional class names used as x-tick labels.
        save_dir: directory receiving labels_correlogram.jpg and labels.jpg.
        loggers: optional dict of external loggers; only a 'wandb' entry is used.
    """
    # plot dataset labels
    print('Plotting labels... ')
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
    colors = color_list()
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

    # seaborn correlogram
    sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    plt.close()

    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(names, rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
    sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

    # rectangles: draw up to 1000 boxes centred on a white 2000x2000 canvas
    labels[:, 1:3] = 0.5  # center
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    for cls, *box in labels[:1000]:
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10])  # plot
    ax[1].imshow(img)
    ax[1].axis('off')

    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
            ax[a].spines[s].set_visible(False)

    plt.savefig(save_dir / 'labels.jpg', dpi=200)
    matplotlib.use('Agg')
    plt.close()

    # loggers — BUG FIX: the original iterated `loggers.items() or {}`, which
    # evaluates `.items()` first and raises AttributeError for the default
    # loggers=None; guard before iterating.
    if loggers:
        for k, v in loggers.items():
            if k == 'wandb' and v:
                v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]},
                      commit=False)
|
||||
|
||||
|
||||
def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
    """Plot hyperparameter-evolution results from evolve.txt, one scatter subplot per hyperparameter.

    Args:
        yaml_file: hyperparameter YAML whose keys define the subplot order.
    """
    # Plot hyperparameter evolution results in evolve.txt
    with open(yaml_file) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # weights = (f - f.min()) ** 2  # for weighted results
    plt.figure(figsize=(10, 12), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, (k, v) in enumerate(hyp.items()):
        # Column i+7: hyperparameter values — assumes the first 7 columns of
        # evolve.txt are result metrics; TODO confirm against the writer.
        y = x[:, i + 7]
        # mu = (y * weights).sum() / weights.sum()  # best weighted result
        mu = y[f.argmax()]  # best single result
        plt.subplot(6, 5, i + 1)
        plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
        if i % 5 != 0:  # keep y tick labels only on the first column of subplots
            plt.yticks([])
        print('%15s: %.3g' % (k, mu))
    plt.savefig('evolve.png', dpi=200)
    print('\nPlot saved as evolve.png')
|
||||
|
||||
|
||||
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
    """Plot per-frame iDetection logs ('frames*.txt' in save_dir), one subplot per metric.

    Args:
        start/stop: frame range to plot (stop=0 means to the end).
        labels: optional legend labels, one per file; file stems are used otherwise.
        save_dir: directory searched for logs and receiving idetection_profile.png.
    """
    # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
    files = list(Path(save_dir).glob('frames*.txt'))
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
            n = results.shape[1]  # number of rows
            x = np.arange(start, min(stop, n) if stop else n)
            results = results[:, x]
            t = (results[0] - results[0].min())  # set t0=0s
            results[0] = x  # replace timestamps with frame indices (t was captured above)
            for i, a in enumerate(ax):
                if i < len(results):
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
                    a.set_title(s[i])
                    a.set_xlabel('time (s)')
                    # if fi == len(files) - 1:
                    #     a.set_ylim(bottom=0)
                    for side in ['top', 'right']:
                        a.spines[side].set_visible(False)
                else:
                    a.remove()  # drop unused subplots
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (f, e))

    ax[1].legend()
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
|
||||
|
||||
|
||||
def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
    """Plot training 'results*.txt' files, overlaying train and val curves per metric group.

    Args:
        start/stop: epoch range to plot (stop=0 means to the end).
    """
    # Plot training 'results*.txt', overlaying train and val losses
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
        ax = ax.ravel()
        for i in range(5):
            for j in [i, i + 5]:  # overlay the train (i) and val (i+5) series
                y = results[j, x]
                ax[i].plot(x, y, marker='.', label=s[j])
                # y_smooth = butter_lowpass_filtfilt(y)
                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])

            ax[i].set_title(t[i])
            ax[i].legend()
            ax[i].set_ylabel(f) if i == 0 else None  # add filename
        fig.savefig(f.replace('.txt', '.png'), dpi=200)
|
||||
|
||||
|
||||
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
    """Plot ten training metrics from 'results*.txt' into <save_dir>/results.png.

    Args:
        start/stop: epoch range to plot (stop=0 means to the end).
        bucket: optional GCS bucket name; files are downloaded with gsutil.
        id: experiment ids used to build results%g.txt names when bucket is set.
        labels: optional legend labels, one per file (file stems otherwise).
        save_dir: directory searched for results*.txt files.
    """
    # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
    if bucket:
        # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
        files = ['results%g.txt' % x for x in id]
        c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
        os.system(c)
    else:
        files = list(Path(save_dir).glob('results*.txt'))
        assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
            n = results.shape[1]  # number of rows
            x = range(start, min(stop, n) if stop else n)
            for i in range(10):
                y = results[i, x]
                if i in [0, 1, 2, 5, 6, 7]:
                    y[y == 0] = np.nan  # don't show zero loss values
                    # y /= y[0]  # normalize
                label = labels[fi] if len(labels) else f.stem
                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
                ax[i].set_title(s[i])
                # if i in [5, 6, 7]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (f, e))

    ax[1].legend()
    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
||||
|
|
@ -0,0 +1,454 @@
|
|||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from kafka.errors import kafka_errors
|
||||
import traceback
|
||||
import json, base64,os
|
||||
import numpy as np
|
||||
from multiprocessing import Process,Queue
|
||||
import time,cv2,string,random
|
||||
import subprocess as sp
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from util.datasets import LoadStreams, LoadImages
|
||||
from util.experimental import attempt_load
|
||||
from util.general import check_img_size, check_requirements, check_imshow, non_max_suppression,overlap_box_suppression, xyxy2xywh
|
||||
|
||||
import torch,sys
|
||||
|
||||
from DrGraph.util.drHelper import *
|
||||
|
||||
#from segutils.segmodel import SegModel,get_largest_contours
|
||||
#sys.path.extend(['../yolov5/segutils'])
|
||||
|
||||
from util.segutils.segWaterBuilding import SegModel,get_largest_contours,illBuildings
|
||||
|
||||
#from segutils.core.models.bisenet import BiSeNet
|
||||
from util.segutils.core.models.bisenet import BiSeNet_MultiOutput
|
||||
|
||||
from util.plots import plot_one_box,plot_one_box_PIL,draw_painting_joint,get_label_arrays,get_websource
|
||||
from collections import Counter
|
||||
#import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
# get_labelnames,get_label_arrays,post_process_,save_problem_images,time_str
|
||||
#FP_DEBUG=open('debut.txt','w')
|
||||
def bsJpgCode(image_ori):
    """JPEG-encode an image array and return the JPEG bytes as a base64 string.

    Args:
        image_ori: BGR image array accepted by cv2.imencode.
    Returns:
        ASCII str holding the base64-encoded JPEG data.
    """
    jpgCode = cv2.imencode('.jpg', image_ori)[-1]  # np.uint8 buffer of JPEG bytes
    # IDIOM FIX: the original used str(b64bytes)[2:-1] to strip the b'...' repr;
    # decoding is explicit and yields the identical string (base64 is pure ASCII).
    return base64.b64encode(jpgCode).decode('ascii')
|
||||
def bsJpgDecode(bsCode):
    """Decode a base64 JPEG string back into a BGR image array (H, W, 3)."""
    raw = np.frombuffer(base64.b64decode(bsCode), np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR)
|
||||
|
||||
rainbows=[
|
||||
(0,0,255),(0,255,0),(255,0,0),(255,0,255),(255,255,0),(255,127,0),(255,0,127),
|
||||
(127,255,0),(0,255,127),(0,127,255),(127,0,255),(255,127,255),(255,255,127),
|
||||
(127,255,255),(0,255,255),(255,127,255),(127,255,255),
|
||||
(0,127,0),(0,0,127),(0,255,255)
|
||||
]
|
||||
|
||||
def check_stream(stream):
    """Return True if `stream` (URL / device index / file) can be opened by OpenCV.

    BUG FIX: the original never released the capture handle, leaking a
    decoder/socket resource on every probe.
    """
    cap = cv2.VideoCapture(stream)
    opened = cap.isOpened()
    cap.release()
    return opened
|
||||
#####
|
||||
def drawWater(pred, image_array0, river=None):
    """Draw the water region derived from a binary segmentation mask.

    Args:
        pred: binary water mask output by the segmentation model.
        image_array0: BGR image; contour lines are drawn on it in place.
        river: options dict {color, line_width, segRegionCnt, segLineShow}.
    Returns:
        (annotated image, water mask with the kept contours filled with 1).
    """
    # FIX: the default was a mutable dict argument shared across calls; use a
    # None sentinel and build the same default per call.
    if river is None:
        river = {'color': (0, 255, 255), 'line_width': 3, 'segRegionCnt': 2, 'segLineShow': True}
    # Extract water-body contours from the mask.
    contours, hierarchy = cv2.findContours(pred, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    water = pred.copy()
    water[:, :] = 0

    if len(contours) == 0:
        return image_array0, water
    # Keep only the largest `segRegionCnt` contours (the main water bodies).
    max_ids = get_largest_contours(contours, river['segRegionCnt'])
    for max_id in max_ids:
        cv2.fillPoly(water, [contours[max_id][:, 0, :]], 1)
        if river['segLineShow']:
            cv2.drawContours(image_array0, contours, max_id, river['color'], river['line_width'])
    return image_array0, water
|
||||
|
||||
|
||||
def post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,iframe,ObjectPar={ 'object_config':[0,1,2,3,4], 'slopeIndex':[5,6,7] ,'segmodel':True,'segRegionCnt':1 },font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3},padInfos=None ,ovlap_thres=None):
    """NMS + box rescaling + water-mask filtering + drawing for one frame.

    Input: data produced by the dataset generator plus the raw model prediction.
    Pipeline: NMS ---> coordinate rescaling ---> seg-mask filtering ---> drawing.
    Output: ([original image, annotated image, det rows [x0,y0,x1,y1,conf,cls],
    iframe], timing string).

    NOTE(review): the `font` default dict has no 'segLineShow' key, but the
    single-output water branch below reads font['segLineShow'] — confirm callers
    always pass a complete font dict.
    """
    object_config,slopeIndex,segmodel,segRegionCnt=ObjectPar['object_config'],ObjectPar['slopeIndex'],ObjectPar['segmodel'],ObjectPar['segRegionCnt']
    time0=time.time()
    path, img, im0s, vid_cap ,pred,seg_pred= datas[0:6];
    #segmodel=True
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)
    if ovlap_thres:
        pred = overlap_box_suppression(pred, ovlap_thres)
    time1=time.time()
    i=0;det=pred[0]  # one image is processed per call
    time1_1 = time.time()
    #p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
    p, s, im0 = path[i], '%g: ' % i, im0s[i]
    time1_2 = time.time()
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
    time1_3 = time.time()
    det_xywh=[];
    #im0_brg=cv2.cvtColor(im0,cv2.COLOR_RGB2BGR);
    if segmodel:
        if len(seg_pred)==2:  # two seg outputs -> joint water + illegal-building model
            im0,water = illBuildings(seg_pred,im0)
        else:
            river={ 'color':font['waterLineColor'],'line_width':font['waterLineWidth'],'segRegionCnt':segRegionCnt,'segLineShow':font['segLineShow'] }
            im0,water = drawWater(seg_pred,im0,river)
    time2=time.time()
    #plt.imshow(im0);plt.show()
    if len(det)>0:
        # Rescale boxes from img_size to im0 size
        if not padInfos:
            det[:, :4] = imgHelper.scale_coords(img.shape[2:], det[:, :4],im0.shape).round()
        else:
            det[:, :4] = imgHelper.scale_back( det[:, :4],padInfos).round()
        # Use the seg mask to keep only valid boxes relative to the river region.
        if segmodel:
            cls_indexs = det[:, 5].clone().cpu().numpy().astype(np.int32)
            # Flag targets belonging to bank-slope classes (kept unconditionally).
            slope_flag = np.array([x in slopeIndex for x in cls_indexs ] )

            det_c = det.clone(); det_c=det_c.cpu().numpy()
            try:
                # Fraction of each box area covered by the water mask.
                area_factors = np.array([np.sum(water[int(x[1]):int(x[3]), int(x[0]):int(x[2])] )*1.0/(1.0*(x[2]-x[0])*(x[3]-x[1])+0.00001) for x in det_c] )
            except:
                # NOTE(review): if this fires, area_factors stays undefined and the
                # next line raises NameError — consider a zeros fallback.
                print('*****************************line143: error:',det_c)
            water_flag = np.array(area_factors>0.1)
            # In-water targets need >0.1 overlap with water; slope targets are always kept.
            det = det[water_flag|slope_flag]
        # Record every surviving detection and draw the requested classes.
        for *xyxy, conf, cls in reversed(det):
            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
            cls_c = cls.cpu().numpy()

            conf_c = conf.cpu().numpy()
            tt=[ int(x.cpu()) for x in xyxy]
            #line = [float(cls_c), *tt, float(conf_c)] # label format
            line = [*tt, float(conf_c), float(cls_c)] # label format
            det_xywh.append(line)
            label = f'{names[int(cls)]} {conf:.2f}'
            if int(cls_c) not in object_config: # class not requested for display: recorded above but not drawn
                continue
            im0 = drawHelper.draw_painting_joint(xyxy,im0,label_arraylist[int(cls)],score=conf,color=rainbows[int(cls)%20],font=font)
    time3=time.time()
    strout='nms:%s drawWater:%s,copy:%s,toTensor:%s,detDraw:%s '%( \
        timeHelper.deltaTime_MS(time0,time1),\
        timeHelper.deltaTime_MS(time1,time2),\
        timeHelper.deltaTime_MS(time1_1,time1_2),\
        timeHelper.deltaTime_MS(time1_2,time1_3), \
        timeHelper.deltaTime_MS(time2,time3) )
    return [im0s[0],im0,det_xywh,iframe],strout
|
||||
|
||||
|
||||
def getDetections(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,iframe,ObjectPar={ 'object_config':[0,1,2,3,4], 'slopeIndex':[5,6,7] ,'segmodel':True,'segRegionCnt':1 },font={ 'line_thickness':None, 'fontSize':None,'boxLine_thickness':None,'waterLineColor':(0,255,255),'waterLineWidth':3},padInfos=None ,ovlap_thres=None):
    """Near-duplicate of post_process_ that returns detections WITHOUT drawing boxes.

    Differences from post_process_: no draw call, and each det row is
    [cls, x0, y0, x1, y1, conf] (class FIRST) instead of [x0..y1, conf, cls].
    NOTE(review): the two row layouts are easy to confuse downstream — consider
    unifying, and factoring the shared NMS/rescale/filter logic into one helper.
    """
    object_config,slopeIndex,segmodel,segRegionCnt=ObjectPar['object_config'],ObjectPar['slopeIndex'],ObjectPar['segmodel'],ObjectPar['segRegionCnt']
    time0=time.time()
    path, img, im0s, vid_cap ,pred,seg_pred= datas[0:6];
    #segmodel=True
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)
    if ovlap_thres:
        pred = overlap_box_suppression(pred, ovlap_thres)
    time1=time.time()
    i=0;det=pred[0]  # one image is processed per call
    time1_1 = time.time()
    #p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
    p, s, im0 = path[i], '%g: ' % i, im0s[i]
    time1_2 = time.time()
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
    time1_3 = time.time()
    det_xywh=[];
    #im0_brg=cv2.cvtColor(im0,cv2.COLOR_RGB2BGR);
    if segmodel:
        if len(seg_pred)==2:  # two seg outputs -> joint water + illegal-building model
            im0,water = illBuildings(seg_pred,im0)
        else:
            river={ 'color':font['waterLineColor'],'line_width':font['waterLineWidth'],'segRegionCnt':segRegionCnt,'segLineShow':font['segLineShow'] }
            im0,water = drawWater(seg_pred,im0,river)
    time2=time.time()
    #plt.imshow(im0);plt.show()
    if len(det)>0:
        # Rescale boxes from img_size to im0 size
        if not padInfos:
            det[:, :4] = imgHelper.scale_coords(img.shape[2:], det[:, :4],im0.shape).round()
        else:
            det[:, :4] = imgHelper.scale_back( det[:, :4],padInfos).round()
        # Use the seg mask to keep only valid boxes relative to the river region.
        if segmodel:
            cls_indexs = det[:, 5].clone().cpu().numpy().astype(np.int32)
            # Flag targets belonging to bank-slope classes (kept unconditionally).
            slope_flag = np.array([x in slopeIndex for x in cls_indexs ] )
            det_c = det.clone(); det_c=det_c.cpu().numpy()
            try:
                # Fraction of each box area covered by the water mask.
                area_factors = np.array([np.sum(water[int(x[1]):int(x[3]), int(x[0]):int(x[2])] )*1.0/(1.0*(x[2]-x[0])*(x[3]-x[1])+0.00001) for x in det_c] )
            except:
                # NOTE(review): on failure area_factors stays undefined and the next
                # line raises NameError — same latent bug as post_process_.
                print('*****************************line143: error:',det_c)
            water_flag = np.array(area_factors>0.1)
            # In-water targets need >0.1 overlap with water; slope targets are always kept.
            det = det[water_flag|slope_flag]
        # Record the surviving detections (no drawing in this variant).
        for *xyxy, conf, cls in reversed(det):
            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
            cls_c = cls.cpu().numpy()

            conf_c = conf.cpu().numpy()
            tt=[ int(x.cpu()) for x in xyxy]
            line = [float(cls_c), *tt, float(conf_c)] # label format
            det_xywh.append(line)
            label = f'{names[int(cls)]} {conf:.2f}'
            if int(cls_c) not in object_config: # class not requested: still recorded above
                continue

    time3=time.time()
    strout='nms:%s drawWater:%s,copy:%s,toTensor:%s,detDraw:%s '%(\
        timeHelper.deltaTime_MS(time0,time1),\
        timeHelper.deltaTime_MS(time1,time2),\
        timeHelper.deltaTime_MS(time1_1,time1_2),\
        timeHelper.deltaTime_MS(time1_2,time1_3), \
        timeHelper.deltaTime_MS(time2,time3) )
    return [im0s[0],im0,det_xywh,iframe],strout
|
||||
|
||||
|
||||
def riverDetSegMixProcess(preds, water, pars=None):
    """Filter detections by their overlap with the water segmentation mask.

    Args:
        preds: list of detection rows [x0, y0, x1, y1, score, cls] — note the
               class id is the LAST element (the original docstring claimed
               [cls, x0, y0, x1, y1, score], which contradicts the x[5] read).
        water: 2-D array of 0/1; 1 marks water pixels.
        pars: dict with
              'slopeIndex': class ids kept unconditionally (bank-slope targets),
              'riverIou': minimum box/water overlap ratio for all other classes.
    Returns:
        (kept rows as np.ndarray, or [] when preds is empty; timing info string)
    """
    # FIX: avoid a mutable default argument; build the same default per call.
    if pars is None:
        pars = {'slopeIndex': list(range(20)), 'riverIou': 0.1}
    assert 'slopeIndex' in pars.keys(), 'input para keys error,No: slopeIndex'
    assert 'riverIou' in pars.keys(), 'input para keys error, No: riverIou'
    time0 = time.time()
    slopeIndex, riverIou = pars['slopeIndex'], pars['riverIou']
    if len(preds) > 0:
        preds = np.array(preds)
        cls_indexs = [int(x[5]) for x in preds]
        # Fraction of each box area covered by water (epsilon guards zero-area boxes).
        area_factors = np.array([np.sum(water[int(x[1]):int(x[3]), int(x[0]):int(x[2])]) * 1.0
                                 / (1.0 * (x[2] - x[0]) * (x[3] - x[1]) + 0.00001) for x in preds])
        slope_flag = np.array([x in slopeIndex for x in cls_indexs])
        water_flag = np.array(area_factors > riverIou)
        # In-water targets need overlap > riverIou; slope targets are always kept.
        det = preds[water_flag | slope_flag]
    else:
        det = []
    time1 = time.time()
    timeInfos = 'all: %.1f ' % ((time1 - time0))
    return det, timeInfos
|
||||
def riverDetSegMixProcess_N(predList, pars={'slopeIndex':list(range(20)),'riverIou':0.1}):
    """Unpack a [preds, water, ...] list and delegate to riverDetSegMixProcess."""
    preds = predList[0]
    water = predList[1]
    return riverDetSegMixProcess(preds, water, pars=pars)
|
||||
|
||||
|
||||
def detectDraw(im0, dets, label_arraylist, rainbows, font):
    """Draw every detection onto im0; each row is [cls, x0, y0, x1, y1, conf]."""
    for row in dets:
        cls_id = int(row[0])
        box = row[1:5]
        score = row[5]
        im0 = drawHelper.draw_painting_joint(box, im0, label_arraylist[cls_id],
                                             score=score, color=rainbows[cls_id % 20], font=font)
    return im0
|
||||
|
||||
|
||||
def preprocess(par):
    """Frame-reader process: pull frames from par['source'] and push
    [path, img, im0s, vid_cap, iframe] items onto par['queOut'].

    Retries forever: on open failure or a mid-stream read error it sleeps 10s,
    resets the frame counter and probes the source again.
    """
    print('#####process:', par['name'])
    # Responsible for reading the video and producing both the original frame and
    # the resized detection input, in numpy format.
    # source example: 'rtmp://liveplay.yunhengzhizao.cn/live/demo_HD5M'
    # img_size=640; stride=32
    while True:
        cap = cv2.VideoCapture(par['source'])
        opened = cap.isOpened()
        cap.release()  # FIX: the probe capture handle was never released (leaked every retry)
        iframe = 0
        if opened:
            print('#### read %s success!' % (par['source']))
            try:
                dataset = LoadStreams(par['source'], img_size=640, stride=32)
                for path, img, im0s, vid_cap in dataset:
                    datas = [path, img, im0s, vid_cap, iframe]
                    par['queOut'].put(datas)
                    iframe += 1
            except Exception as e:
                print('###read error:%s ' % (par['source']))
                time.sleep(10)
                iframe = 0

        else:
            print('###read error:%s ' % (par['source']))
            time.sleep(10)
            iframe = 0
|
||||
|
||||
def gpu_process(par):
    """GPU worker process: run the detector + segmenter on frames from par['queIn'].

    Consumes [path, img, im0s, vid_cap, iframe] items, appends the raw detector
    output `pred` and segmentation output `seg_pred`, and forwards
    [path, img, im0s, vid_cap, pred, seg_pred, iframe] to par['queOut'].
    """
    print('#####process:',par['name'])
    half=True  # run the detector in FP16
    # GPU work: detection model
    weights = par['weights']
    device = par['device']
    print('###line127:',par['device'])
    model = attempt_load(par['weights'], map_location=par['device']) # load FP32 model
    if half:
        model.half()

    # GPU work: segmentation model
    seg_nclass = par['seg_nclass']
    seg_weights = par['seg_weights']

    #segmodel = SegModel(nclass=seg_nclass,weights=seg_weights,device=device)


    nclass = [2,2]  # two binary outputs: water + building
    Segmodel = BiSeNet_MultiOutput(nclass)
    # NOTE(review): this hard-coded path shadows the `weights` read from par above,
    # and par['seg_weights'] is read but never used — confirm this is intended.
    weights='weights/segmentation/WaterBuilding.pth'
    segmodel = SegModel(model=Segmodel,nclass=nclass,weights=weights,device='cuda:0',multiOutput=True)
    while True:
        if not par['queIn'].empty():
            time0=time.time()
            datas = par['queIn'].get()
            path, img, im0s, vid_cap,iframe = datas[0:5]
            time1=time.time()
            img = torch.from_numpy(img).to(device)
            img = img.half() if half else img.float() # uint8 to fp16/32
            img /= 255.0 # 0 - 255 to 0.0 - 1.0
            time2 = time.time()
            pred = model(img,augment=False)[0]
            time3 = time.time()
            seg_pred = segmodel.eval(im0s[0],outsize=None,smooth_kernel=20)
            time4 = time.time()
            fpStr= 'process:%s ,iframe:%d,getdata:%s,copygpu:%s,dettime:%s,segtime:%s , time:%s, queLen:%d '%( par['name'],iframe,\
                timeHelper.deltaTime_MS(time0,time1) ,\
                timeHelper.deltaTime_MS(time1,time2) ,\
                timeHelper.deltaTime_MS(time2,time3) ,\
                timeHelper.deltaTime_MS(time3,time4),\
                timeHelper.deltaTime_MS(time0,time4) , \
                par['queIn'].qsize() )
            #FP_DEBUG.write( fpStr+'\n' )
            datasOut = [path, img, im0s, vid_cap,pred,seg_pred,iframe]
            par['queOut'].put(datasOut)
            if par['debug']:
                print('#####process:',par['name'],' line107')
        else:
            time.sleep(1/300)  # brief idle poll when the input queue is empty
|
||||
def get_cls(array):
    """Return the class id occurring most often in `array` (first seen wins ties)."""
    counts = Counter(array)
    best_cls, _ = max(counts.items(), key=lambda kv: kv[1])
    return int(best_cls)
|
||||
def save_problem_images(post_results,iimage_cnt,names,streamName='live-THSAHD5M',outImaDir='problems/images_tmp',imageTxtFile=False):
    """Pick the best-scoring frame from a batch of positive results and save it to disk.

    Args:
        post_results: list of [original img, annotated img, det rows, iframe] items.
        iimage_cnt: counter embedded in the output filename.
        names: class-name list used in the filename 'type' field.
        streamName/outImaDir: filename components / destination directory.
        imageTxtFile: when True, also write a .txt sidecar per saved image.
    Returns:
        dict with the chosen images, uid, filenames, timestamp and type name.
    """
    ## det row element 5 is read as both "conf" (mean below) and "cls" (get_cls);
    ## NOTE(review): post_process_ rows end [conf, cls] while getDetections rows
    ## end [..., conf] with cls first — index 5 means different things depending
    ## on the producer. Confirm which one feeds this function.
    problem_image=[[] for i in range(6)]


    dets_list = [x[2] for x in post_results]

    mean_scores=[ np.array(x)[:,5].mean() for x in dets_list ] ###mean conf (see layout note above)

    best_index = mean_scores.index(max(mean_scores)) # index of the "problem" frame within the batch
    best_frame = post_results[ best_index][3] # absolute frame number
    img_send = post_results[best_index][1] # AI-annotated image
    img_bak = post_results[best_index][0] # original image
    cls_max = get_cls( x[5] for x in dets_list[best_index] )


    time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    uid=''.join(random.sample(string.ascii_letters + string.digits, 16))
    #ori_name = '2022-01-20-15-57-36_frame-368-720_type-漂浮物_qVh4zI08ZlwJN9on_s-live-THSAHD5M_OR.jpg'
    #2022-01-13-15-07-57_frame-9999-9999_type-结束_9999999999999999_s-off-XJRW20220110115904_AI.jpg
    # NOTE(review): the 'OR' variable is given the _AI.jpg suffix and vice versa,
    # and parOut['imgOR'] below holds the AI image — the OR/AI pairs look swapped.
    outnameOR= '%s/%s_frame-%d-%d_type-%s_%s_s-%s_AI.jpg'%(outImaDir,time_str,best_frame,iimage_cnt,names[cls_max],uid,streamName)
    outnameAR= '%s/%s_frame-%d-%d_type-%s_%s_s-%s_OR.jpg'%(outImaDir,time_str,best_frame,iimage_cnt,names[cls_max],uid,streamName)

    cv2.imwrite(outnameOR,img_send)
    try:
        cv2.imwrite(outnameAR,img_bak)
    except:
        # NOTE(review): numpy arrays expose .size as an attribute, not a method —
        # img_bak.size() would itself raise inside this handler.
        print(outnameAR,type(img_bak),img_bak.size())
    if imageTxtFile:
        outnameOR_txt = outnameOR.replace('.jpg','.txt')
        fp=open(outnameOR_txt,'w');fp.write(outnameOR+'\n');fp.close()
        outnameAI_txt = outnameAR.replace('.jpg','.txt')
        fp=open(outnameAI_txt,'w');fp.write(outnameAR+'\n');fp.close()

    parOut = {}; parOut['imgOR'] = img_send; parOut['imgAR'] = img_send; parOut['uid']=uid
    parOut['imgORname']=os.path.basename(outnameOR);parOut['imgARname']=os.path.basename(outnameAR);
    parOut['time_str'] = time_str;parOut['type'] = names[cls_max]
    return parOut
|
||||
|
||||
|
||||
|
||||
|
||||
def post_process(par):
    """Post-processing process: apply NMS/filter/draw to items from par['queIn'],
    forward results to par['queOut'], and periodically save a "problem" image.

    Every par['fpsample'] frames, the buffered positive results are handed to
    save_problem_images and the buffer is cleared.
    """

    print('#####process:',par['name'])
    # post-process parameters
    conf_thres,iou_thres,classes=par['conf_thres'],par['iou_thres'],par['classes']
    labelnames=par['labelnames']
    rainbows=par['rainbows']
    fpsample = par['fpsample']
    names = ioHelper.get_labelnames(labelnames)
    label_arraylist = get_label_arrays(names,rainbows,outfontsize=40)
    iimage_cnt = 0
    post_results=[]  # frames with detections, buffered between samples
    while True:
        if not par['queIn'].empty():
            time0=time.time()
            datas = par['queIn'].get()
            iframe = datas[6]
            if par['debug']:
                print('#####process:',par['name'],' line129')
            p_result,timeOut = post_process_(datas,conf_thres, iou_thres,names,label_arraylist,rainbows,iframe)
            par['queOut'].put(p_result)
            # Every `fpsample` frames: if any positives were buffered, save one problem image.
            if (iframe % fpsample == 0) and (len(post_results)>0) :
                #print('####line204:',iframe,post_results)
                save_problem_images(post_results,iframe,names)
                post_results=[]

            if len(p_result[2] )>0: ## buffer frames that contain detections
                #post_list = p_result.append(iframe)
                post_results.append(p_result)
                #print('####line201:',type(p_result))

            time1=time.time()
            outstr='process:%s ,iframe:%d,%s , time:%s, queLen:%d '%( par['name'],iframe,timeOut,\
                timeHelper.deltaTime_MS(time0,time1) ,par['queIn'].qsize() )
            #FP_DEBUG.write(outstr +'\n')
            #print( 'process:%s ,iframe:%d,%s , time:%s, queLen:%d '%( par['name'],iframe,timeOut,timeHelper.deltaTime_MS(time0,time1) ,par['queIn'].qsize() ) )
        else:
            time.sleep(1/300)  # brief idle poll when the input queue is empty
|
||||
|
||||
|
||||
def save_logfile(name, txt):
    """Append a timestamped line of text to the log file `name`.

    BUG FIX: the original opened existing files with mode 'r+', which writes at
    offset 0 and silently overwrites the oldest log lines; append mode preserves
    all prior entries (and still creates the file when missing). The `with`
    statement also guarantees the handle is closed on error.
    """
    with open(name, 'a') as fp:
        fp.write('%s %s \n' % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), txt))
|
||||
def time_str():
    """Return the current local time formatted as 'YYYY-mm-dd HH:MM:SS'."""
    fmt = "%Y-%m-%d %H:%M:%S"
    return time.strftime(fmt, time.localtime())
|
||||
|
||||
|
||||
|
||||
if __name__=='__main__':
    # Entry point: run the streaming analysis pipeline using the river-queue config.
    jsonfile='config/queRiver.json'
    #image_encode_decode()
    # NOTE(review): work_stream is not defined in this file — confirm it is
    # imported/provided elsewhere before running this module directly.
    work_stream(jsonfile)
    #par={'name':'preprocess'}
    #preprocess(par)
|
||||
|
|
@ -0,0 +1,501 @@
|
|||
#@@ -1,43 +1,43 @@
|
||||
# GPUtil - GPU utilization
|
||||
#
|
||||
# A Python module for programmically getting the GPU utilization from NVIDA GPUs using nvidia-smi
|
||||
#
|
||||
# Author: Anders Krogh Mortensen (anderskm)
|
||||
# Date: 16 January 2017
|
||||
# Web: https://github.com/anderskm/gputil
|
||||
#
|
||||
# LICENSE
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2017 anderskm
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from subprocess import Popen, PIPE
|
||||
from distutils import spawn
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Version of the vendored GPUtil-derived module this file is based on.
__version__ = '1.4.0'
|
||||
class GPU:
    """Snapshot of a single NVIDIA device as reported by nvidia-smi.

    Plain data holder; one instance per device line of the
    ``--query-gpu`` CSV output.
    """

    def __init__(self, ID, uuid, load, memoryTotal, memoryUsed, memoryFree, driver, gpu_name, serial, display_mode, display_active, temp_gpu):
        # NOTE: assignment order is kept stable on purpose — __str__ dumps
        # the instance dict, so insertion order is observable output.
        self.id = ID
        self.uuid = uuid
        self.load = load
        # Fraction of total memory currently in use (0.0 .. 1.0).
        self.memoryUtil = float(memoryUsed) / float(memoryTotal)
        self.memoryTotal = memoryTotal
        self.memoryUsed = memoryUsed
        self.memoryFree = memoryFree
        self.driver = driver
        self.name = gpu_name
        self.serial = serial
        self.display_mode = display_mode
        self.display_active = display_active
        self.temperature = temp_gpu

    def __str__(self):
        # Dump every attribute for quick debugging / logging.
        return str(vars(self))
|
||||
|
||||
|
||||
class GPUProcess:
    """One compute process from nvidia-smi ``--query-compute-apps``."""

    def __init__(self, pid, processName, gpuId, gpuUuid, gpuName, usedMemory,
                 uid, uname):
        # NOTE: assignment order kept stable — __str__ dumps the instance
        # dict, so insertion order is observable output.
        self.pid = pid
        self.processName = processName
        self.gpuId = gpuId
        self.gpuUuid = gpuUuid
        self.gpuName = gpuName
        self.usedMemory = usedMemory
        # Owner of the process, resolved via `ps` (-1/'' when unknown).
        self.uid = uid
        self.uname = uname

    def __str__(self):
        # Dump every attribute for quick debugging / logging.
        return str(vars(self))
|
||||
|
||||
def safeFloatCast(strNumber):
    """Cast *strNumber* to float, returning NaN instead of raising.

    nvidia-smi emits '[Not Supported]' or 'N/A' for some fields; those
    raise ValueError. Non-castable objects such as None raise TypeError,
    which the original version let escape — both now map to NaN so the
    function is actually "safe".
    """
    try:
        return float(strNumber)
    except (ValueError, TypeError):
        return float('nan')
|
||||
|
||||
#def getGPUs():
|
||||
def getNvidiaSmiCmd():
    """Locate the nvidia-smi executable.

    On Windows, nvidia-smi is often not on PATH, so after a PATH lookup
    fails we fall back to the default NVSMI install location. (The
    original code skipped the PATH lookup entirely, contradicting its
    own comment and upstream GPUtil behavior.) Elsewhere we rely on the
    shell's PATH resolution.
    """
    import shutil  # local import keeps this fix self-contained
    if platform.system() == "Windows":
        nvidia_smi = shutil.which('nvidia-smi')
        if nvidia_smi is None:
            # Fall back to the standard driver install directory.
            nvidia_smi = "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" % os.environ['systemdrive']
    else:
        nvidia_smi = "nvidia-smi"
    return nvidia_smi
|
||||
|
||||
|
||||
def getGPUs():
    """Query nvidia-smi and return a list of GPU objects, one per device.

    Returns an empty list when nvidia-smi cannot be executed (no driver,
    no device, binary missing).

    Fixes over the original:
    - nvidia-smi was invoked TWICE (a dead Popen/communicate pair was
      immediately overwritten by subprocess.run) — now invoked once.
    - a dead `for i in range(12)` parse loop duplicated the direct
      per-column assignments — removed.
    - `stdout.split(os.linesep)` mis-parses on Windows because
      subprocess with `encoding=` translates newlines to '\\n';
      splitlines() is correct on every platform and also drops the
      trailing empty line the old `len(lines)-1` hack worked around.
    """
    nvidia_smi = getNvidiaSmiCmd()
    try:
        p = subprocess.run([
            nvidia_smi,
            "--query-gpu=index,uuid,utilization.gpu,memory.total,memory.used,memory.free,driver_version,name,gpu_serial,display_active,display_mode,temperature.gpu",
            "--format=csv,noheader,nounits"
        ], stdout=subprocess.PIPE, encoding='utf8')
        stdout = p.stdout
    except Exception:
        # Boundary: report "no GPUs" instead of crashing the caller.
        return []

    GPUs = []
    for line in stdout.splitlines():
        if not line.strip():
            continue
        # Columns arrive in the exact order requested by --query-gpu.
        vals = line.split(', ')
        deviceIds = int(vals[0])
        uuid = vals[1]
        # utilization.gpu is a percentage; normalize to 0.0 .. 1.0.
        gpuUtil = safeFloatCast(vals[2]) / 100
        memTotal = safeFloatCast(vals[3])
        memUsed = safeFloatCast(vals[4])
        memFree = safeFloatCast(vals[5])
        driver = vals[6]
        gpu_name = vals[7]
        serial = vals[8]
        display_active = vals[9]
        display_mode = vals[10]
        temp_gpu = safeFloatCast(vals[11])
        GPUs.append(GPU(deviceIds, uuid, gpuUtil, memTotal, memUsed, memFree,
                        driver, gpu_name, serial, display_mode, display_active, temp_gpu))
    return GPUs
|
||||
|
||||
|
||||
|
||||
def getGPUProcesses():
    """Get all GPU compute processes as a list of GPUProcess objects.

    Also rebuilds the module-level ``gpuUuidToIdMap`` so each process's
    gpu uuid can be translated to a numeric device id. Returns an empty
    list when nvidia-smi cannot be executed.

    Fixes over the original:
    - ``gpuUuidToIdMap[gpuUuid]`` raised KeyError when the map could not
      be built, making the ``if gpuId is None`` fallback unreachable;
      ``.get()`` restores the intended -1 fallback.
    - ``split(os.linesep)`` replaced by splitlines() (subprocess with
      ``encoding=`` always yields '\\n').
    """
    global gpuUuidToIdMap
    gpuUuidToIdMap = {}
    try:
        gpus = getGPUs()
        for gpu in gpus:
            gpuUuidToIdMap[gpu.uuid] = gpu.id
        del gpus
    except Exception:
        # Best-effort: an empty map simply yields gpuId = -1 below.
        pass

    nvidia_smi = getNvidiaSmiCmd()
    try:
        p = subprocess.run([
            nvidia_smi,
            "--query-compute-apps=pid,process_name,gpu_uuid,gpu_name,used_memory",
            "--format=csv,noheader,nounits"
        ], stdout=subprocess.PIPE, encoding='utf8')
        stdout = p.stdout
    except Exception:
        return []

    processes = []
    for line in stdout.splitlines():
        if not line.strip():
            continue
        vals = line.split(', ')
        pid = int(vals[0])
        processName = vals[1]
        gpuUuid = vals[2]
        gpuName = vals[3]
        usedMemory = safeFloatCast(vals[4])
        # .get() avoids a KeyError when the uuid map could not be built.
        gpuId = gpuUuidToIdMap.get(gpuUuid)
        if gpuId is None:
            gpuId = -1

        # Resolve uid and uname of the process owner via `ps`.
        try:
            p = subprocess.run(['ps', f'-p{pid}', '-oruid=,ruser='],
                               stdout=subprocess.PIPE, encoding='utf8')
            uid, uname = p.stdout.split()
            uid = int(uid)
        except Exception:
            # Process may have exited, or `ps` unavailable (Windows).
            uid, uname = -1, ''

        processes.append(GPUProcess(pid, processName, gpuId, gpuUuid,
                                    gpuName, usedMemory, uid, uname))
    return processes
|
||||
|
||||
|
||||
def getAvailable(order='first', limit=1, maxLoad=0.5, maxMemory=0.5, memoryFree=0, includeNan=False, excludeID=None, excludeUUID=None):
    """Return the ids of up to *limit* available GPUs, sorted per *order*.

    order:
        'first'  -- lowest device id first (default)
        'last'   -- highest device id first
        'random' -- random order among available GPUs
        'load'   -- lowest utilization first
        'memory' -- lowest memory utilization first
    limit caps how many ids are returned; fewer are returned when fewer
    GPUs qualify. Availability criteria are those of getAvailability().

    excludeID/excludeUUID now default to None instead of shared mutable
    lists (classic mutable-default pitfall); passing lists still works.
    """
    excludeID = [] if excludeID is None else excludeID
    excludeUUID = [] if excludeUUID is None else excludeUUID
    # Get device IDs, load and memory usage
    GPUs = getGPUs()
    # Determine which GPUs are available
    GPUavailability = getAvailability(GPUs, maxLoad=maxLoad, maxMemory=maxMemory, memoryFree=memoryFree, includeNan=includeNan, excludeID=excludeID, excludeUUID=excludeUUID)
    availAbleGPUindex = [idx for idx in range(0, len(GPUavailability)) if (GPUavailability[idx] == 1)]
    # Discard unavailable GPUs
    GPUs = [GPUs[g] for g in availAbleGPUindex]
    # Sort available GPUs according to the order argument; NaN keys are
    # pushed to the end of the ranking.
    if (order == 'first'):
        GPUs.sort(key=lambda x: float('inf') if math.isnan(x.id) else x.id, reverse=False)
    elif (order == 'last'):
        GPUs.sort(key=lambda x: float('-inf') if math.isnan(x.id) else x.id, reverse=True)
    elif (order == 'random'):
        GPUs = [GPUs[g] for g in random.sample(range(0, len(GPUs)), len(GPUs))]
    elif (order == 'load'):
        GPUs.sort(key=lambda x: float('inf') if math.isnan(x.load) else x.load, reverse=False)
    elif (order == 'memory'):
        GPUs.sort(key=lambda x: float('inf') if math.isnan(x.memoryUtil) else x.memoryUtil, reverse=False)
    # Keep at most `limit` GPUs and return their device ids.
    GPUs = GPUs[0:min(limit, len(GPUs))]
    deviceIds = [gpu.id for gpu in GPUs]
    return deviceIds
|
||||
#def getAvailability(GPUs, maxLoad = 0.5, maxMemory = 0.5, includeNan = False):
|
||||
# # Determine, which GPUs are available
|
||||
# GPUavailability = np.zeros(len(GPUs))
|
||||
# for i in range(len(GPUs)):
|
||||
# if (GPUs[i].load < maxLoad or (includeNan and np.isnan(GPUs[i].load))) and (GPUs[i].memoryUtil < maxMemory or (includeNan and np.isnan(GPUs[i].memoryUtil))):
|
||||
# GPUavailability[i] = 1
|
||||
def getAvailability(GPUs, maxLoad=0.5, maxMemory=0.5, memoryFree=0, includeNan=False, excludeID=None, excludeUUID=None):
    """Return a parallel list of 0/1 flags: 1 when the GPU is available.

    A GPU is available when it has at least *memoryFree* MB free, its
    load and memory utilization are below the given limits (NaN values
    count as "below" only when includeNan is True), and it is not listed
    in excludeID / excludeUUID.

    excludeID/excludeUUID default to None instead of shared mutable
    lists; the giant one-line comprehension is decomposed for
    readability — behavior is unchanged.
    """
    excludeID = [] if excludeID is None else excludeID
    excludeUUID = [] if excludeUUID is None else excludeUUID

    def _is_available(gpu):
        # Each criterion mirrors one clause of the original expression.
        return (gpu.memoryFree >= memoryFree
                and (gpu.load < maxLoad or (includeNan and math.isnan(gpu.load)))
                and (gpu.memoryUtil < maxMemory or (includeNan and math.isnan(gpu.memoryUtil)))
                and gpu.id not in excludeID
                and gpu.uuid not in excludeUUID)

    return [1 if _is_available(gpu) else 0 for gpu in GPUs]
|
||||
def getFirstAvailable(order='first', maxLoad=0.5, maxMemory=0.5, attempts=1, interval=900, verbose=False, includeNan=False, excludeID=None, excludeUUID=None, memoryFree=0):
    """Poll for one available GPU, retrying up to *attempts* times.

    Sleeps *interval* seconds between attempts and raises RuntimeError
    when no GPU becomes available. Returns the (single-element) list of
    device ids produced by getAvailable().

    Fixes over the original:
    - `available` could be referenced unbound when attempts < 1; it is
      now initialized to an empty list.
    - a `memoryFree` criterion is now forwarded to getAvailable()
      (trailing keyword parameter, default 0 — backward compatible).
    - excludeID/excludeUUID no longer use shared mutable defaults.
    """
    excludeID = [] if excludeID is None else excludeID
    excludeUUID = [] if excludeUUID is None else excludeUUID
    available = []
    for i in range(attempts):
        if (verbose):
            print('Attempting (' + str(i+1) + '/' + str(attempts) + ') to locate available GPU.')
        # Get first available GPU
        available = getAvailable(order=order, limit=1, maxLoad=maxLoad, maxMemory=maxMemory, memoryFree=memoryFree, includeNan=includeNan, excludeID=excludeID, excludeUUID=excludeUUID)
        # If an available GPU was found, break the retry loop.
        if (available):
            if (verbose):
                print('GPU ' + str(available) + ' located!')
            break
        # Not the last attempt: wait before retrying.
        if (i != attempts-1):
            time.sleep(interval)
    # All attempts exhausted without finding a GPU: raise.
    if (not(available)):
        raise RuntimeError('Could not find an available GPU after ' + str(attempts) + ' attempts with ' + str(interval) + ' seconds interval.')
    return available
|
||||
def showUtilization(all=False, attrList=None, useOldCode=False):
    """Print a utilization table for every GPU to stdout.

    all       -- True prints the full attribute set; False only ID/GPU/MEM.
    attrList  -- optional custom column spec (list of groups of dicts with
                 keys 'attr', 'name' and optional 'suffix', 'transform',
                 'precision'); only honored when useOldCode is False.
    useOldCode -- True uses fixed-format legacy printing instead of the
                 dynamic column-width code below.
    """
    GPUs = getGPUs()
    if (all):
        if (useOldCode):
            # Legacy fixed-width table with every attribute.
            print(' ID | Name | Serial | UUID || GPU util. | Memory util. || Memory total | Memory used | Memory free || Display mode | Display active |')
            print('------------------------------------------------------------------------------------------------------------------------------')
            for gpu in GPUs:
                print(' {0:2d} | {1:s} | {2:s} | {3:s} || {4:3.0f}% | {5:3.0f}% || {6:.0f}MB | {7:.0f}MB | {8:.0f}MB || {9:s} | {10:s}'.format(gpu.id,gpu.name,gpu.serial,gpu.uuid,gpu.load*100,gpu.memoryUtil*100,gpu.memoryTotal,gpu.memoryUsed,gpu.memoryFree,gpu.display_mode,gpu.display_active))
        else:
            # Full column spec: each inner list is a group separated by '|'.
            attrList = [[{'attr':'id','name':'ID'},
                         {'attr':'name','name':'Name'},
                         {'attr':'serial','name':'Serial'},
                         {'attr':'uuid','name':'UUID'}],
                        [{'attr':'temperature','name':'GPU temp.','suffix':'C','transform': lambda x: x,'precision':0},
                         {'attr':'load','name':'GPU util.','suffix':'%','transform': lambda x: x*100,'precision':0},
                         {'attr':'memoryUtil','name':'Memory util.','suffix':'%','transform': lambda x: x*100,'precision':0}],
                        [{'attr':'memoryTotal','name':'Memory total','suffix':'MB','precision':0},
                         {'attr':'memoryUsed','name':'Memory used','suffix':'MB','precision':0},
                         {'attr':'memoryFree','name':'Memory free','suffix':'MB','precision':0}],
                        [{'attr':'display_mode','name':'Display mode'},
                         {'attr':'display_active','name':'Display active'}]]

    else:
        if (useOldCode):
            # Legacy compact table: id, load %, memory %.
            print(' ID GPU MEM')
            print('--------------')
            for gpu in GPUs:
                print(' {0:2d} {1:3.0f}% {2:3.0f}%'.format(gpu.id, gpu.load*100, gpu.memoryUtil*100))
        else:
            # Compact column spec (id, load %, memory %).
            attrList = [[{'attr':'id','name':'ID'},
                         {'attr':'load','name':'GPU','suffix':'%','transform': lambda x: x*100,'precision':0},
                         {'attr':'memoryUtil','name':'MEM','suffix':'%','transform': lambda x: x*100,'precision':0}],
                        ]

    if (not useOldCode):
        if (attrList is not None):
            headerString = ''
            GPUstrings = ['']*len(GPUs)
            for attrGroup in attrList:
                #print(attrGroup)
                for attrDict in attrGroup:
                    headerString = headerString + '| ' + attrDict['name'] + ' '
                    headerWidth = len(attrDict['name'])
                    minWidth = len(attrDict['name'])

                    # Optional per-column formatting hooks.
                    attrPrecision = '.' + str(attrDict['precision']) if ('precision' in attrDict.keys()) else ''
                    attrSuffix = str(attrDict['suffix']) if ('suffix' in attrDict.keys()) else ''
                    attrTransform = attrDict['transform'] if ('transform' in attrDict.keys()) else lambda x : x
                    # First pass: find the widest rendered value so the
                    # column is wide enough for every GPU.
                    for gpu in GPUs:
                        attr = getattr(gpu,attrDict['attr'])

                        attr = attrTransform(attr)

                        if (isinstance(attr,float)):
                            attrStr = ('{0:' + attrPrecision + 'f}').format(attr)
                        elif (isinstance(attr,int)):
                            attrStr = ('{0:d}').format(attr)
                        elif (isinstance(attr,str)):
                            attrStr = attr;
                        elif (sys.version_info[0] == 2):
                            # Python 2 only: `unicode` does not exist on py3.
                            if (isinstance(attr,unicode)):
                                attrStr = attr.encode('ascii','ignore')
                        else:
                            raise TypeError('Unhandled object type (' + str(type(attr)) + ') for attribute \'' + attrDict['name'] + '\'')

                        attrStr += attrSuffix

                        minWidth = max(minWidth,len(attrStr))

                    # Pad the header to the computed column width.
                    headerString += ' '*max(0,minWidth-headerWidth)

                    minWidthStr = str(minWidth - len(attrSuffix))

                    # Second pass: render each GPU's value at that width.
                    for gpuIdx,gpu in enumerate(GPUs):
                        attr = getattr(gpu,attrDict['attr'])

                        attr = attrTransform(attr)

                        if (isinstance(attr,float)):
                            attrStr = ('{0:'+ minWidthStr + attrPrecision + 'f}').format(attr)
                        elif (isinstance(attr,int)):
                            attrStr = ('{0:' + minWidthStr + 'd}').format(attr)
                        elif (isinstance(attr,str)):
                            attrStr = ('{0:' + minWidthStr + 's}').format(attr);
                        elif (sys.version_info[0] == 2):
                            if (isinstance(attr,unicode)):
                                attrStr = ('{0:' + minWidthStr + 's}').format(attr.encode('ascii','ignore'))
                        else:
                            raise TypeError('Unhandled object type (' + str(type(attr)) + ') for attribute \'' + attrDict['name'] + '\'')

                        attrStr += attrSuffix

                        GPUstrings[gpuIdx] += '| ' + attrStr + ' '

                # Close the column group with a separator.
                headerString = headerString + '|'
                for gpuIdx,gpu in enumerate(GPUs):
                    GPUstrings[gpuIdx] += '|'

            headerSpacingString = '-' * len(headerString)
            print(headerString)
            print(headerSpacingString)
            for GPUstring in GPUstrings:
                print(GPUstring)
|
||||
|
||||
|
||||
# Generate gpu uuid to id map
# Built once at import time so code that only has a gpu_uuid (e.g. the
# __main__ demo below) can translate it into a numeric device id.
# getGPUProcesses() rebuilds this map on every call.
gpuUuidToIdMap = {}
try:
    gpus = getGPUs()
    for gpu in gpus:
        gpuUuidToIdMap[gpu.uuid] = gpu.id
    del gpus
except:
    # NOTE(review): bare except silently leaves the map empty when GPU
    # enumeration fails at import time — presumably intentional best-effort.
    pass
|
||||
def getGPUInfos():
    """Return the GPUs, each augmented with a ``process`` list.

    Every GPU object (see getGPUs(): id, load, memoryFree, memoryTotal,
    memoryUsed, memoryUtil, name, serial, temperature, uuid) gains a
    ``process`` attribute holding the GPUProcess objects running on it;
    each process additionally gains a numeric ``gpuid`` attribute.
    """
    gpus = getGPUs()
    uuid_to_id = {}
    for dev in gpus:
        uuid_to_id[dev.uuid] = dev.id
        dev.process = []
    device_ids = [dev.id for dev in gpus]

    # Attach every compute process to the GPU it runs on.
    for proc in getGPUProcesses():
        proc.gpuid = uuid_to_id[proc.gpuUuid]
        owner = device_ids.index(proc.gpuid)
        gpus[owner].process.append(proc)
    return gpus
|
||||
|
||||
def get_available_gpu(gpuStatus):
    """Return the id of the first GPU with no running process, else None."""
    for candidate in gpuStatus:
        if not candidate.process:
            return candidate.id
    return None
|
||||
def get_whether_gpuProcess():
    """Return True when no GPU hosts any compute process, else False."""
    for card in getGPUInfos():
        if card.process:
            return False
    return True
|
||||
|
||||
def get_offlineProcess_gpu(gpuStatus, pidInfos):
    """Return the GPUs that host no process marked 'onLine' in *pidInfos*.

    Processes whose pid is unknown to *pidInfos* count as off-line.
    """
    busy_online = []
    for card in gpuStatus:
        for proc in card.process:
            if proc.pid in pidInfos and pidInfos[proc.pid]['type'] == 'onLine':
                busy_online.append(card)
    # Set difference keeps each off-line GPU exactly once (order unspecified).
    return list(set(gpuStatus) - set(busy_online))
|
||||
def arrange_offlineProcess(gpuStatus, pidInfos, modelMemory=1500):
    """Return a list of GPU ids, one entry per *modelMemory*-MB model slot.

    Only GPUs without on-line tasks are considered; each contributes as
    many slots as fit in 90% of its total memory minus current usage.
    """
    slots = []
    for card in get_offlineProcess_gpu(gpuStatus, pidInfos):
        # Keep a 10% safety margin on total memory.
        spare = card.memoryTotal * 0.9 - card.memoryUsed
        slots.extend([card.id] * int(spare // modelMemory))
    return slots
|
||||
def get_potential_gpu(gpuStatus, pidInfos):
    """Pick the GPU to free up for an on-line task.

    Among GPUs without on-line processes, choose the one hosting the
    fewest (off-line) processes. Returns False when every GPU already
    runs an on-line task; otherwise a dict with the chosen device id
    ('cuda') and the pids to evict ('pids').
    """
    # Step 1: GPUs that carry no on-line task.
    offline = get_offlineProcess_gpu(gpuStatus, pidInfos)
    if not offline:
        return False

    # Step 2: least-loaded off-line GPU (first minimum on ties).
    target = min(offline, key=lambda card: len(card.process))
    return {'cuda': target.id, 'pids': [proc.pid for proc in target.process]}
|
||||
if __name__=='__main__':
    # Smoke test: enumerate GPUs, populate the uuid->id map, then dump
    # every compute process with its attributes.
    #pres = getGPUProcesses()
    #print('###line404:',pres)
    gpus = getGPUs()
    for gpu in gpus:
        gpuUuidToIdMap[gpu.uuid] = gpu.id
        print(gpu)
    print(gpuUuidToIdMap)
    pres = getGPUProcesses()
    print('###line404:',pres)
    for pre in pres:
        # Separator, then one line per attribute of interest.
        print('#'*20)
        for ken in ['gpuName','gpuUuid','pid','processName','uid','uname','usedMemory' ]:
            print(ken,' ',pre.__getattribute__(ken ))
        print(' ')
|
||||
|
||||
|
||||