143 lines
6.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
import sys
|
|
from pickle import dumps, loads
|
|
from traceback import format_exc
|
|
|
|
import cv2
|
|
from loguru import logger
|
|
|
|
from enums.ExceptionEnum import ExceptionType
|
|
from enums.ModelTypeEnum import ModelType
|
|
from exception.CustomerException import ServiceException
|
|
from util.PlotsUtils import get_label_arrays
|
|
from util.TorchUtils import select_device
|
|
|
|
sys.path.extend(['..', '../AIlib2'])
|
|
from AI import AI_process
|
|
from stdc import stdcModel
|
|
from models.experimental import attempt_load
|
|
import tensorrt as trt
|
|
from DMPR import DMPRModel
|
|
# Font file passed to get_label_arrays (as fontPath) when pre-rendering label text.
# NOTE(review): this is a relative path; if the font loader opens it as-is it is
# resolved against the process's current working directory, not this file's
# location — confirm the service is always started from the expected directory.
FONT_PATH = "../AIlib2/conf/platech.ttf"
|
|
|
|
def get_label_arraylist(*args):
    """Compute drawing settings and pre-rendered labels for a given frame size.

    Positional arguments, unpacked in this order: ``width``, ``height``,
    ``names``, ``rainbows``. ``height`` is accepted but the current sizing
    formula only depends on ``width``.

    Returns:
        A 3-tuple ``(digitFont, label_arraylist, font_config)`` where
        ``digitFont`` is the dict of drawing options, ``label_arraylist`` is the
        result of ``get_label_arrays(names, rainbows, ...)`` and ``font_config``
        is ``(line, text_width, text_height, fontScale, tf)``.
    """
    width, height, names, rainbows = args
    # Everything is scaled relative to a 1920-pixel-wide frame.
    scale = width / 1920
    line = max(1, int(round(scale * 3)))
    tf = max(line - 1, 1)
    fontScale = line * 0.33
    # Measure a representative score string with OpenCV's default font (0).
    (text_width, text_height), _baseline = cv2.getTextSize(' 0.95', 0, fontScale=fontScale, thickness=tf)
    numFontSize = float(format(scale * 1.1, '.1f'))
    digitFont = dict(
        line_thickness=line,
        boxLine_thickness=line,
        fontSize=numFontSize,
        waterLineColor=(0, 255, 255),
        segLineShow=False,
        waterLineWidth=line,
        wordSize=text_height,
        label_location='leftTop',
    )
    font_config = (line, text_width, text_height, fontScale, tf)
    return digitFont, get_label_arrays(names, rainbows, fontSize=text_height, fontPath=FONT_PATH), font_config
|
|
|
|
# Used for: river model, river detection model, traffic model,
# person-fallen-into-water model, common urban-violation model.
class OneModel:
    """Load the detection (and optional segmentation) model described by a ``ModelType``.

    After construction ``self.model_conf`` holds the tuple
    ``(modeType, model_param, allowedList, names, rainbows)``; ``model_process``
    and ``one_label`` read it by position.
    """

    __slots__ = ("model_conf",)

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Build the model parameters and load the weights.

        Args:
            device: Device identifier; ``str(device)`` is passed to the model
                type's config factory (``modeType.value[4]``).
            allowedList: Stored unchanged in ``model_conf``.
            requestId: Only used in log messages.
            modeType: ``ModelType`` member. ``value[2]`` is logged, ``value[4]``
                is called as ``(str(device), gpu_name)`` to obtain the config
                dict, ``value[3]`` selects the segmentation model class and
                ``value[0] == '3'`` enables ``allowedList`` filtering.
            gpu_name: Forwarded to the config factory.
            base_dir: Accepted for interface compatibility; not used here.
            env: Accepted for interface compatibility; not used here.

        Raises:
            ServiceException: ``MODEL_LOADING_EXCEPTION`` for any failure while
                loading; the underlying error is chained as ``__cause__``.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            new_device = select_device(par.get('device'))
            # Half precision is only enabled when not running on the CPU.
            half = new_device.type != 'cpu'
            Detweights = par['Detweights']
            if par['trtFlag_det']:
                # Detweights points to a serialized TensorRT engine.
                with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                    model = runtime.deserialize_cuda_engine(f.read())
            else:
                model = attempt_load(Detweights, map_location=new_device)  # load FP32 model
                if half:
                    model.half()
            # segPar is optional (read with .get above). Only annotate it when it is
            # present, so configs without a segPar entry no longer fail here.
            # segPar is the same dict object later stored in model_param.
            if segPar is not None:
                segPar['seg_nclass'] = par['seg_nclass']
            Segweights = par['Segweights']
            if Segweights:
                if modeType.value[3] == 'cityMangement3':
                    segmodel = DMPRModel(weights=Segweights, par=segPar)
                else:
                    segmodel = stdcModel(weights=Segweights, par=segPar)
            else:
                segmodel = None
            objectPar = {
                'half': half,
                'device': new_device,
                'conf_thres': postFile["conf_thres"],
                'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
                'iou_thres': postFile["iou_thres"],
                # Filtering by allowedList applies only to the highway model (type code '3').
                'allowedList': par['allowedList'] if modeType.value[0] == '3' else [],
                'segRegionCnt': par['segRegionCnt'],
                'trtFlag_det': par['trtFlag_det'],
                'trtFlag_seg': par['trtFlag_seg'],
            }
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": objectPar,
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as err:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from err
|
|
|
|
def model_process(args):
    """Run ``AI_process`` on a single frame with a loaded model configuration.

    Args:
        args: ``(model_conf, frame, request_id)``. ``model_conf`` is the tuple
            built by ``OneModel`` — ``(modeType, model_param, allowedList,
            names, rainbows)``. ``model_param`` must already contain
            ``digitFont`` and ``label_arraylist``, which ``one_label`` adds.

    Returns:
        Whatever ``AI_process`` returns for the one-element frame list.

    Raises:
        ServiceException: re-raised unchanged if raised during processing;
            any other error is logged and wrapped as
            ``MODEL_ANALYSE_EXCEPTION`` with the original chained as ``__cause__``.
    """
    model_conf, frame, request_id = args
    model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
    try:
        return AI_process([frame], model_param['model'], model_param['segmodel'], names,
                          model_param['label_arraylist'], rainbows,
                          objectPar=model_param['objectPar'], font=model_param['digitFont'],
                          # The pickle round-trip hands AI_process a fresh deep copy of segPar,
                          # presumably so per-frame changes cannot leak into the shared config.
                          segPar=loads(dumps(model_param['segPar'])), mode=model_param['mode'],
                          postPar=model_param['postPar'])
    except ServiceException:
        raise
    except Exception as err:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from err
|
|
|
|
def one_label(width, height, model_conf):
    """Attach frame-size-dependent drawing resources to a model configuration.

    ``model_conf`` is indexed as ``(modeType, model_param, allowedList, names,
    rainbows)``. The three results of ``get_label_arraylist`` are stored in
    ``model_param`` (mutated in place) under ``digitFont``, ``label_arraylist``
    and ``font_config``. Returns None.
    """
    label_names, label_colors = model_conf[3], model_conf[4]
    params = model_conf[1]
    font, labels, metrics = get_label_arraylist(width, height, label_names, label_colors)
    params.update(digitFont=font, label_arraylist=labels, font_config=metrics)
|
|
|
|
|
|
# Model registry. Each entry is a 4-tuple:
#   (model factory, ModelType member, label initializer, frame processor).
MODEL_CONFIG = {
    # River detection model
    ModelType.RIVER2_MODEL.value[1]: (
        # Factory arguments, in order: device, allowedList, requestId, gpu_name, base_dir, env.
        lambda x, y, r, t, z, h: OneModel(x, y, r, modeType=ModelType.RIVER2_MODEL,
                                          gpu_name=t, base_dir=z, env=h),
        ModelType.RIVER2_MODEL,
        lambda x, y, z: one_label(width=x, height=y, model_conf=z),
        lambda x: model_process(args=x),
    ),
}
|