|
|
@@ -446,7 +446,7 @@ class ForestModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -463,7 +463,7 @@ def forest_process(param): |
|
|
|
""" |
|
|
|
|
|
|
|
return AI_process_forest([param[0]], param[1], param[2], param[3], param[4], param[5], param[6], param[7], |
|
|
|
param[8], param[9], param[10], font=param[11], trtFlag_det=param[12]) |
|
|
|
param[8], param[9], param[10], font=param[11], trtFlag_det=param[12], SecNms=param[14]) |
|
|
|
except ServiceException as s: |
|
|
|
raise s |
|
|
|
except Exception: |
|
|
@@ -527,7 +527,7 @@ class VehicleModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -609,7 +609,7 @@ class PedestrianModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -671,7 +671,7 @@ class SmogfireModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -733,7 +733,7 @@ class AnglerSwimmerModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -796,7 +796,7 @@ class CountryRoadModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -859,7 +859,7 @@ class ChannelEmergencyModel: |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId] |
|
|
|
trtFlag_det, requestId, None] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
@@ -868,6 +868,70 @@ class ChannelEmergencyModel: |
|
|
|
logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId) |
|
|
|
|
|
|
|
|
|
|
|
# 城管模型 |
|
|
|
class CityMangementModel: |
|
|
|
__slots__ = "model_conf" |
|
|
|
|
|
|
|
def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None): |
|
|
|
s = time.time() |
|
|
|
try: |
|
|
|
logger.info("########################加载城管模型########################, requestId:{}", requestId) |
|
|
|
trtFlag_det = True, # 检测模型是否采用TRT |
|
|
|
# trtFlag_seg = False, # 分割模型是否采用TRT |
|
|
|
# 公共变量 |
|
|
|
par = { |
|
|
|
'device': str(device1), |
|
|
|
'gpu_name': gpu_name, |
|
|
|
'labelnames': ["车辆", "垃圾"], |
|
|
|
'seg_nclass': 2, # 分割模型类别数目,默认2类 |
|
|
|
'segRegionCnt': 0, |
|
|
|
'slopeIndex': [], |
|
|
|
'segPar': None, |
|
|
|
'postFile': { |
|
|
|
"name": "post_process", |
|
|
|
"conf_thres": 0.25, |
|
|
|
"iou_thres": 0.45, |
|
|
|
"ovlap_thres_crossCategory": 0.6, |
|
|
|
"classes": 2, |
|
|
|
"rainbows": COLOR |
|
|
|
}, |
|
|
|
'segweights': None |
|
|
|
} |
|
|
|
if trtFlag_det: |
|
|
|
par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name']) |
|
|
|
else: |
|
|
|
par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3] |
|
|
|
|
|
|
|
device = select_device(par.get('device')) # 指定GPU |
|
|
|
names = par.get('labelnames') |
|
|
|
half = device.type != 'cpu' |
|
|
|
Detweights = par.get('detweights') # 升级后的检测模型 |
|
|
|
if trtFlag_det: |
|
|
|
with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime: |
|
|
|
model = runtime.deserialize_cuda_engine(f.read()) |
|
|
|
else: |
|
|
|
model = attempt_load(Detweights, map_location=device) |
|
|
|
if half: |
|
|
|
model.half() |
|
|
|
segmodel = None |
|
|
|
conf_thres = par.get('postFile').get("conf_thres") |
|
|
|
iou_thres = par.get('postFile').get("iou_thres") |
|
|
|
# classes = par.get('postFile').get("classes") |
|
|
|
rainbows = par.get('postFile').get("rainbows") |
|
|
|
ovlap_thres_crossCategory = par.get('postFile').get("ovlap_thres_crossCategory") |
|
|
|
""" |
|
|
|
frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres, |
|
|
|
allowedList, digitFont, trtFlag_det, requestId |
|
|
|
""" |
|
|
|
model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None, |
|
|
|
trtFlag_det, requestId, ovlap_thres_crossCategory] |
|
|
|
self.model_conf = (modeType, allowedList, model_param) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
|
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], |
|
|
|
ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) |
|
|
|
logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId) |
|
|
|
|
|
|
|
# 船只模型 |
|
|
|
class ShipModel: |
|
|
|
__slots__ = "model_conf" |
|
|
@@ -1232,6 +1296,17 @@ def smogfire_label(width, model_param): |
|
|
|
model_param[4] = label_arraylist |
|
|
|
model_param[11] = digitFont |
|
|
|
|
|
|
|
def city_mangement_label(width, model_param): |
|
|
|
names = model_param[3] |
|
|
|
rainbows = model_param[5] |
|
|
|
digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows) |
|
|
|
""" |
|
|
|
frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres, |
|
|
|
allowedList, digitFont, trtFlag_det, requestId, over |
|
|
|
""" |
|
|
|
model_param[4] = label_arraylist |
|
|
|
model_param[11] = digitFont |
|
|
|
|
|
|
|
|
|
|
|
MODEL_CONFIG = { |
|
|
|
# 加载河道模型 |
|
|
@@ -1321,7 +1396,13 @@ MODEL_CONFIG = { |
|
|
|
lambda x, y, r, t, z: River2Model(x, y, r, ModelType.RIVER2_MODEL, t, z), |
|
|
|
ModelType.RIVER2_MODEL, |
|
|
|
lambda x, y: river2_label(x, y), |
|
|
|
lambda x: model_process(x) |
|
|
|
lambda x: model_process(x)), |
|
|
|
# 城管模型 |
|
|
|
ModelType.CITY_MANGEMENT_MODEL.value[1]: ( |
|
|
|
lambda x, y, r, t, z: CityMangementModel(x, y, r, ModelType.CITY_MANGEMENT_MODEL, t, z), |
|
|
|
ModelType.CITY_MANGEMENT_MODEL, |
|
|
|
lambda x, y: city_mangement_label(x, y), |
|
|
|
lambda x: forest_process(x) |
|
|
|
) |
|
|
|
} |
|
|
|
# ModelConfig = namedtuple('ModelConfig', ('x', 'y', 'z')) |