# -*- coding: utf-8 -*-
import sys
from pickle import dumps, loads
from traceback import format_exc

import cv2
from loguru import logger

from enums.ExceptionEnum import ExceptionType
from enums.ModelTypeEnum import ModelType
from exception.CustomerException import ServiceException
from util.PlotsUtils import get_label_arrays
from util.TorchUtils import select_device

# The AIlib2 modules imported below are located via relative paths, so the
# search path has to be extended before those imports run.
sys.path.extend(['..', '../AIlib2'])
from AI import AI_process
from stdc import stdcModel
from models.experimental import attempt_load
import tensorrt as trt
from DMPR import DMPRModel

FONT_PATH = "../AIlib2/conf/platech.ttf"


def get_label_arraylist(*args):
    """Compute drawing/font settings and label images for a given frame size.

    Positional args are ``(width, height, names, rainbows)``. ``height`` is
    unpacked but not used by the current sizing rules.

    Returns a 3-tuple:
      * ``digitFont`` - drawing-options dict (passed to ``AI_process`` as ``font``);
      * ``label_arraylist`` - result of ``get_label_arrays`` for ``names``/``rainbows``;
      * ``(line, text_width, text_height, fontScale, tf)`` - the raw metrics.
    """
    width, height, names, rainbows = args
    # Line thickness scales with frame width (3 at a width of 1920), never below 1.
    line = max(1, int(round(width / 1920 * 3)))
    # Sample label text, used only to measure the rendered text size.
    label = ' 0.95'
    tf = max(line - 1, 1)
    fontScale = line * 0.33
    # Font face 0 is cv2.FONT_HERSHEY_SIMPLEX.
    text_width, text_height = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=tf)[0]
    # Numeric font size, rounded to one decimal place (1.1 at a width of 1920).
    numFontSize = float(format(width / 1920 * 1.1, '.1f'))
    digitFont = {'line_thickness': line,
                 'boxLine_thickness': line,
                 'fontSize': numFontSize,
                 'waterLineColor': (0, 255, 255),
                 'segLineShow': False,
                 'waterLineWidth': line,
                 'wordSize': text_height,
                 'label_location': 'leftTop'}
    label_arraylist = get_label_arrays(names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
    return digitFont, label_arraylist, (line, text_width, text_height, fontScale, tf)


# River model, river detection model, traffic model, drowning-person model,
# common urban-violation model.
class OneModel:
    """Load a detection model (plus an optional segmentation model) for one ModelType.

    The loaded state is kept in ``self.model_conf`` as the tuple
    ``(modeType, model_param, allowedList, names, rainbows)``, which is the
    shape consumed by ``one_label`` and ``model_process``.
    """
    __slots__ = ("model_conf",)

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Build ``self.model_conf``.

        :param device: device id; passed as ``str(device)`` to the model-type parameter factory.
        :param allowedList: stored unchanged as element 2 of ``model_conf``.
        :param requestId: used only in log messages.
        :param modeType: ``ModelType`` member; ``value[4]`` is called as
                         ``value[4](str(device), gpu_name)`` to obtain the parameter dict.
        :param gpu_name: forwarded to the parameter factory.
        :param base_dir: not used by this method.
        :param env: not used by this method.
        :raises ServiceException: MODEL_LOADING_EXCEPTION on any failure while loading.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            new_device = select_device(par.get('device'))
            # Half precision is used on any non-CPU device.
            half = new_device.type != 'cpu'
            Detweights = par['Detweights']
            if par['trtFlag_det']:
                # Deserialize a prebuilt TensorRT engine from the weights file.
                with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                    model = runtime.deserialize_cuda_engine(f.read())
            else:
                model = attempt_load(Detweights, map_location=new_device)  # load FP32 model
                if half:
                    model.half()
            # Note: par['segPar'] is the same object as the local `segPar`, so this update is visible through both.
            par['segPar']['seg_nclass'] = par['seg_nclass']
            Segweights = par['Segweights']
            if Segweights:
                if modeType.value[3] == 'cityMangement3':
                    segmodel = DMPRModel(weights=Segweights, par=par['segPar'])
                else:
                    segmodel = stdcModel(weights=Segweights, par=par['segPar'])
            else:
                segmodel = None
            objectPar = {
                'half': half,
                'device': new_device,
                'conf_thres': postFile["conf_thres"],
                'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
                'iou_thres': postFile["iou_thres"],
                # Filtering is applied only to highway models (type code '3').
                'allowedList': par['allowedList'] if modeType.value[0] == '3' else [],
                'segRegionCnt': par['segRegionCnt'],
                'trtFlag_det': par['trtFlag_det'],
                'trtFlag_seg': par['trtFlag_seg']
            }
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": objectPar,
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as e:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from e


def model_process(args):
    """Run ``AI_process`` on a single frame.

    :param args: ``(model_conf, frame, request_id)``. ``model_conf`` is
                 ``OneModel.model_conf`` whose ``model_param`` has already been
                 given ``digitFont`` and ``label_arraylist`` by ``one_label``.
    :return: whatever ``AI_process`` returns.
    :raises ServiceException: re-raised unchanged if ``AI_process`` raises one;
                              otherwise MODEL_ANALYSE_EXCEPTION for any other error.
    """
    model_conf, frame, request_id = args
    model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
    try:
        return AI_process([frame], model_param['model'], model_param['segmodel'], names,
                          model_param['label_arraylist'], rainbows,
                          objectPar=model_param['objectPar'],
                          font=model_param['digitFont'],
                          # Deep copy via a pickle round-trip, so any mutation during this
                          # call does not touch the shared segPar dict.
                          segPar=loads(dumps(model_param['segPar'])),
                          mode=model_param['mode'],
                          postPar=model_param['postPar'])
    except ServiceException:
        # Already a domain error: propagate it with its original traceback.
        raise
    except Exception as e:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from e


def one_label(width, height, model_conf):
    """Attach frame-size-dependent font and label data to ``model_conf`` in place.

    Writes ``digitFont``, ``label_arraylist`` and ``font_config`` into the
    ``model_param`` dict (element 1 of ``model_conf``), using the names
    (element 3) and rainbows (element 4) stored there. Returns ``None``.
    """
    names = model_conf[3]
    rainbows = model_conf[4]
    model_param = model_conf[1]
    digitFont, label_arraylist, font_config = get_label_arraylist(width, height, names, rainbows)
    model_param['digitFont'] = digitFont
    model_param['label_arraylist'] = label_arraylist
    model_param['font_config'] = font_config


# Each entry maps a model code to the tuple:
#   (factory(device, allowedList, requestId, gpu_name, base_dir, env) -> OneModel,
#    ModelType member,
#    label setup callable (width, height, model_conf),
#    frame processing callable ((model_conf, frame, request_id)))
MODEL_CONFIG = {
    # River detection model
    ModelType.RIVER2_MODEL.value[1]: (
        lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.RIVER2_MODEL, t, z, h),
        ModelType.RIVER2_MODEL,
        one_label,
        model_process,
    ),
}