diff --git a/common/Constant.py b/common/Constant.py index c147670..1cd515d 100644 --- a/common/Constant.py +++ b/common/Constant.py @@ -16,7 +16,6 @@ success_progess = "1.0000" width = 1400 COLOR = ( - [0, 0, 255], [255, 0, 0], [211, 0, 148], [0, 127, 0], @@ -35,7 +34,8 @@ COLOR = ( [8, 101, 139], [171, 130, 255], [139, 112, 74], - [205, 205, 180]) + [205, 205, 180], + [0, 0, 255],) ONLINE = "online" OFFLINE = "offline" diff --git a/concurrency/FileUploadThread.py b/concurrency/FileUploadThread.py index 1be5139..5930754 100644 --- a/concurrency/FileUploadThread.py +++ b/concurrency/FileUploadThread.py @@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from threading import Thread from time import sleep, time from traceback import format_exc +import numpy as np from loguru import logger import cv2 @@ -14,18 +15,18 @@ from util.AliyunSdk import AliyunOssSdk from util.MinioSdk import MinioSdk from util import TimeUtils from enums.AnalysisStatusEnum import AnalysisStatus -from util.PlotsUtils import draw_painting_joint, draw_name_ocr, draw_name_crowd +from util.PlotsUtils import draw_painting_joint, draw_name_ocr, draw_name_crowd,draw_transparent_red_polygon from util.QueUtil import put_queue, get_no_block_queue, clear_queue import io from util.LocationUtils import locate_byMqtt class FileUpload(Thread): - __slots__ = ('_fb_queue', '_context', '_image_queue', '_analyse_type', '_msg', '_mqtt_list') + __slots__ = ('_fb_queue', '_context', '_image_queue', '_analyse_type', '_msg') def __init__(self, *args): super().__init__() - self._fb_queue, self._context, self._msg, self._image_queue, self._analyse_type, self._mqtt_list = args + self._fb_queue, self._context, self._msg, self._image_queue, self._analyse_type = args self._storage_source = self._context['service']['storage_source'] self._algStatus = False # 默认关闭 @@ -64,16 +65,29 @@ class ImageFileUpload(FileUpload): 模型编号:modeCode 检测目标:detectTargetCode ''' - print('*' * 100, ' mqtt_list:', len(self._mqtt_list)) - + aFrame = frame.copy() + igH, igW = aFrame.shape[0:2] model_info = [] + mqttPares= det_xywh['mqttPares'] + border = None + gps = [None, None] + camParas = None + if mqttPares is not None: + if mqttPares[0] == 1: + border = mqttPares[1] + elif mqttPares[0] == 0: + camParas = mqttPares[1] + if border is not None: + aFrame = draw_transparent_red_polygon(aFrame, np.array(border, np.int32), alpha=0.25) + det_xywh.pop('mqttPares') # 更加模型编码解析数据 for code, det_list in det_xywh.items(): if len(det_list) > 0: for cls, target_list in det_list.items(): if len(target_list) > 0: - aFrame = frame.copy() for target in target_list: + if camParas is not None: + gps = locate_byMqtt(target[1], igW, igH, camParas, outFormat='wgs84') # 自研车牌模型判断 if ModelType.CITY_CARPLATE_MODEL.value[1] == str(code): draw_name_ocr(target[1], aFrame, target[4]) @@ -82,15 +96,7 @@ class ImageFileUpload(FileUpload): draw_name_crowd(target[1], aFrame, target[4]) else: draw_painting_joint(target[1], aFrame, target[3], target[2], target[4], font_config, - target[5]) - - igH, igW = aFrame.shape[0:2] - if len(self._mqtt_list) >= 1: - # camParas = self._mqtt_list[0]['data'] - camParas = self._mqtt_list[0] - gps = locate_byMqtt(target[1], igW, igH, camParas, outFormat='wgs84') - else: - gps = [None, None] + target[5],border) model_info.append( {"modelCode": str(code), "detectTargetCode": str(cls), "aFrame": aFrame, 'gps': gps}) if len(model_info) > 0: @@ -133,7 +139,6 @@ class ImageFileUpload(FileUpload): # 获取队列中的消息 image_msg = get_no_block_queue(image_queue) if image_msg is not None: - if image_msg[0] == 2: logger.info("图片上传线程收到命令:{}, requestId: {}", image_msg[1], request_id) if 'stop' == image_msg[1]: diff --git a/concurrency/IntelligentRecognitionProcess.py b/concurrency/IntelligentRecognitionProcess.py index 663878b..a4856d4 100644 --- a/concurrency/IntelligentRecognitionProcess.py +++ b/concurrency/IntelligentRecognitionProcess.py @@ -64,6 +64,7 @@ class IntelligentRecognitionProcess(Process): self._analyse_type, progress=init_progess), timeout=2, is_ex=True) self._storage_source = self._context['service']['storage_source'] self._algStatus = False + def sendEvent(self, eBody): put_queue(self.event_queue, eBody, timeout=2, is_ex=True) @@ -127,6 +128,7 @@ class OnlineIntelligentRecognitionProcess(IntelligentRecognitionProcess): or_url = upload_video_thread_or.get_result() ai_url = upload_video_thread_ai.get_result() return or_url, ai_url + ''' @staticmethod def upload_video(base_dir, env, request_id, orFilePath, aiFilePath): @@ -306,7 +308,7 @@ class OnlineIntelligentRecognitionProcess(IntelligentRecognitionProcess): else: model_param = model_conf[1] # (modeType, model_param, allowedList, names, rainbows) - MODEL_CONFIG[code][2](frame_list[0].shape[1], frame_list[0].shape[0], + MODEL_CONFIG[code][2](frame_list[0][0].shape[1], frame_list[0][0].shape[0], model_conf) if draw_config.get("font_config") is None: draw_config["font_config"] = model_param['font_config'] @@ -322,7 +324,7 @@ class OnlineIntelligentRecognitionProcess(IntelligentRecognitionProcess): # print_cpu_status(requestId=request_id,lineNum=inspect.currentframe().f_lineno) # 多线程并发处理, 经过测试两个线程最优 det_array = [] - for i, frame in enumerate(frame_list): + for i, [frame,_] in enumerate(frame_list): det_result = t.submit(self.obj_det, self, model_array, frame, task_status, frame_index_list[i], tt, request_id) det_array.append(det_result) @@ -469,6 +471,7 @@ class OfflineIntelligentRecognitionProcess(IntelligentRecognitionProcess): ai_url = upload_video_thread_ai.get_result() return ai_url ''' + @staticmethod def ai_normal_dtection(model, frame, request_id): model_conf, code = model @@ -634,7 +637,7 @@ class OfflineIntelligentRecognitionProcess(IntelligentRecognitionProcess): else: model_param = model_conf[1] # (modeType, model_param, allowedList, names, rainbows) - MODEL_CONFIG[code][2](frame_list[0].shape[1], frame_list[0].shape[0], + MODEL_CONFIG[code][2](frame_list[0][0].shape[1], frame_list[0][0].shape[0], model_conf) if draw_config.get("font_config") is None: draw_config["font_config"] = model_param['font_config'] @@ -648,7 +651,7 @@ class OfflineIntelligentRecognitionProcess(IntelligentRecognitionProcess): det_array = [] - for i, frame in enumerate(frame_list): + for i, [frame,_] in enumerate(frame_list): det_result = t.submit(self.obj_det, self, model_array, frame, task_status, frame_index_list[i], tt, request_id) det_array.append(det_result) @@ -1051,6 +1054,8 @@ class PhotosIntelligentRecognitionProcess(Process): ai_result_list = p_result[2] for ai_result in ai_result_list: box, score, cls = xywh2xyxy2(ai_result) + if ModelType.CITY_FIREAREA_MODEL.value[1] == str(code): + box.append(ai_result[-1]) # 如果检测目标在识别任务中,继续处理 if cls in allowedList: label_array = label_arraylist[cls] @@ -1208,7 +1213,8 @@ class PhotosIntelligentRecognitionProcess(Process): image_thread.setDaemon(True) image_thread.start() return image_thread - def check_ImageUrl_Vaild(self,url,timeout=1): + + def check_ImageUrl_Vaild(self, url, timeout=1): try: # 发送 HTTP 请求,尝试访问图片 response = requests.get(url, timeout=timeout) # 设置超时时间为 10 秒 @@ -1239,8 +1245,7 @@ class PhotosIntelligentRecognitionProcess(Process): ExceptionType.URL_ADDRESS_ACCESS_FAILED.value[0], ExceptionType.URL_ADDRESS_ACCESS_FAILED.value[1]), timeout=2) - return - + return with ThreadPoolExecutor(max_workers=1) as t: try: @@ -1318,6 +1323,7 @@ class ScreenRecordingProcess(Process): recording_feedback(self._msg["request_id"], RecordingStatus.RECORDING_WAITING.value[0]), timeout=1, is_ex=True) self._storage_source = self._context['service']['storage_source'] + def sendEvent(self, result): put_queue(self._event_queue, result, timeout=2, is_ex=True) @@ -1495,6 +1501,7 @@ class ScreenRecordingProcess(Process): upload_video_thread_ai.start() or_url = upload_video_thread_ai.get_result() return or_url + ''' @staticmethod def upload_video(base_dir, env, request_id, orFilePath): diff --git a/concurrency/PullMqttThread.py b/concurrency/PullMqttThread.py index fe208fb..9a39091 100644 --- a/concurrency/PullMqttThread.py +++ b/concurrency/PullMqttThread.py @@ -1,142 +1,163 @@ -# -*- coding: utf-8 -*- -from threading import Thread -from time import sleep, time -from traceback import format_exc - -from loguru import logger -from common.YmlConstant import mqtt_yml_path -from util.RWUtils import getConfigs -from common.Constant import init_progess -from enums.AnalysisStatusEnum import AnalysisStatus -from entity.FeedBack import message_feedback -from enums.ExceptionEnum import ExceptionType -from exception.CustomerException import ServiceException -from util.QueUtil import get_no_block_queue, put_queue, clear_queue -from multiprocessing import Process, Queue -import paho.mqtt.client as mqtt -import json,os -class PullMqtt(Thread): - __slots__ = ('__fb_queue', '__mqtt_list', '__request_id', '__analyse_type', "_context") - - def __init__(self, *args): - super().__init__() - self.__fb_queue, self.__mqtt_list, self.__request_id, self.__analyse_type, self._context = args - - base_dir, env = self._context["base_dir"], self._context["env"] - self.__config = getConfigs(os.path.join(base_dir, mqtt_yml_path % env)) - - self.__broker = self.__config["broker"] - self.__port = self.__config["port"] - self.__topic = self.__config["topic"] - self.__lengthMqttList = self.__config["length"] - - - def put_queue(self,__queue,data): - if __queue.full(): - a = __queue.get() - __queue.put( data,block=True, timeout=2 ) - def on_connect(self,client,userdata,flags,rc): - client.subscribe(self.__topic) - - - - # 当接收到MQTT消息时,回调函数 - def on_message(self,client, userdata, msg): - # 将消息解码为JSON格式 - payload = msg.payload.decode('utf-8') - data = json.loads(payload) - #logger.info(str(data)) - - - # 解析位姿信息 - lon = data.get("lon") - lat = data.get("lat") - alt = data.get("alt") - yaw = data.get("yaw") - pitch = data.get("pitch") - roll = data.get("roll") - - if len(self.__mqtt_list) == self.__lengthMqttList: - self.__mqtt_list.pop(0) - self.__mqtt_list.append(data) - - - # 打印无人机的位姿信息 - #print(f"Longitude: {lon}, Latitude: {lat}, Altitude: {alt}, sat:{data.get('satcount')} , list length:{len(self.__mqtt_list)}") - - def mqtt_connect(self): - # 创建客户端 - self.client = mqtt.Client() - self.client.on_connect = self.on_connect - # 设置回调函数 - self.client.on_message = self.on_message - - # 连接到 Broker - self.client.connect(self.__broker, self.__port) - - # 订阅主题 - self.client.subscribe(self.__topic) - # 循环等待并处理网络事件 - self.client.loop_forever() - - def mqtt_disconnect(self): - start_time = time() - while True: - if time() - start_time > service_timeout: - logger.error("MQTT读取超时, requestId: %s,限定时间:%.1s , 已运行:%.1fs"%(request_id,service_timeout, time() - start_time)) - raise ServiceException(ExceptionType.TASK_EXCUTE_TIMEOUT.value[0], - ExceptionType.TASK_EXCUTE_TIMEOUT.value[1]) - client.loop_stop() # 停止循环 - client.disconnect() # 断开连接 - - def run(self): - request_id, mqtt_list, progress = self.__request_id, self.__mqtt_list, init_progess - analyse_type, fb_queue = self.__analyse_type, self.__fb_queue - #service_timeout = int(self.__config["service"]["timeout"]) + 120 - - try: - logger.info("开始MQTT读取线程!requestId:{}", request_id) - mqtt_init_num = 0 - self.mqtt_connect() - - except Exception: - logger.error("MQTT线程异常:{}, requestId:{}", format_exc(), request_id) - finally: - mqtt_list = [] - logger.info("MQTT线程停止完成!requestId:{}", request_id) - - -def start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context): - mqtt_thread = PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context) - mqtt_thread.setDaemon(True) - mqtt_thread.start() - return mqtt_thread -def start_PullVideo(mqtt_list): - for i in range(1000): - sleep(1) - if len(mqtt_list)>=10: - print( mqtt_list[4]) - print(i,len(mqtt_list)) -if __name__=="__main__": - #context = {'service':{'timeout':3600},'mqtt':{ - # 'broker':"101.133.163.127",'port':1883,'topic':"test/topic","length":10} - # } - context = { - 'base_dir':'/home/th/WJ/test/tuoheng_algN', - 'env':'test' - - } - analyse_type = '1' - request_id = '123456789' - event_queue, pull_queue, mqtt_list, image_queue, push_queue, push_ex_queue = Queue(), Queue(10), [], Queue(), Queue(), Queue() - fb_queue = Queue() - mqtt_thread = start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context) - - - start_PullVideo(mqtt_list) - print('---line117--') - - - - #mqtt_thread.join() +# -*- coding: utf-8 -*- +from threading import Thread +from time import sleep, time +from traceback import format_exc + +from loguru import logger +from common.YmlConstant import mqtt_yml_path +from util.RWUtils import getConfigs +from common.Constant import init_progess +from enums.AnalysisStatusEnum import AnalysisStatus +from entity.FeedBack import message_feedback +from enums.ExceptionEnum import ExceptionType +from exception.CustomerException import ServiceException +from util.QueUtil import get_no_block_queue, put_queue, clear_queue +from multiprocessing import Process, Queue +import paho.mqtt.client as mqtt +import json,os +class PullMqtt(Thread): + __slots__ = ('__fb_queue', '__mqtt_list', '__request_id', '__analyse_type', "_context" ,'__business') + + def __init__(self, *args): + super().__init__() + self.__fb_queue, self.__mqtt_list, self.__request_id, self.__analyse_type, self._context, self.__business = args + + base_dir, env = self._context["base_dir"], self._context["env"] + self.__config = getConfigs(os.path.join(base_dir, mqtt_yml_path % env)) + if self.__business == 0: + self.__broker = self.__config['location']["broker"] + self.__port = self.__config['location']["port"] + self.__topic = self.__config['location']["topic"] + elif self.__business == 1: + self.__broker = self.__config['invade']["broker"] + self.__port = self.__config['invade']["port"] + self.__topic = self.__config['invade']["topic"] + self.__lengthMqttList = self.__config["length"] + + + def put_queue(self,__queue,data): + if __queue.full(): + a = __queue.get() + __queue.put( data,block=True, timeout=2 ) + def on_connect(self,client,userdata,flags,rc): + client.subscribe(self.__topic) + + + + # 当接收到MQTT消息时,回调函数 + def on_location(self,client, userdata, msg): + # 将消息解码为JSON格式 + payload = msg.payload.decode('utf-8') + data = json.loads(payload) + #logger.info(str(data)) + # 解析位姿信息 + lon = data.get("lon") + lat = data.get("lat") + alt = data.get("alt") + yaw = data.get("yaw") + pitch = data.get("pitch") + roll = data.get("roll") + + if len(self.__mqtt_list) == self.__lengthMqttList: + self.__mqtt_list.pop(0) + self.__mqtt_list.append([self.__business,data]) + + + # 打印无人机的位姿信息 + #print(f"Longitude: {lon}, Latitude: {lat}, Altitude: {alt}, sat:{data.get('satcount')} , list length:{len(self.__mqtt_list)}") + + def on_invade(self, client, userdata, msg): + # 将消息解码为JSON格式 + payload = msg.payload.decode('utf-8') + data = json.loads(payload) + # logger.info(str(data)) + # 解析位姿信息 + points = data.get("points") + + if len(self.__mqtt_list) == self.__lengthMqttList: + self.__mqtt_list.pop(0) + self.__mqtt_list.append([self.__business,points]) + + # 打印无人机的位姿信息 + # print(f"Longitude: {lon}, Latitude: {lat}, Altitude: {alt}, sat:{data.get('satcount')} , list length:{len(self.__mqtt_list)}") + + def mqtt_connect(self): + # 创建客户端 + self.client = mqtt.Client() + self.client.on_connect = self.on_connect + if self.__business == 0: + # 设置回调函数 + self.client.on_message = self.on_location + elif self.__business == 1: + # 设置回调函数 + self.client.on_message = self.on_invade + + # 连接到 Broker + self.client.connect(self.__broker, self.__port) + + # 订阅主题 + self.client.subscribe(self.__topic) + # 循环等待并处理网络事件 + self.client.loop_forever() + + def mqtt_disconnect(self): + start_time = time() + while True: + if time() - start_time > service_timeout: + logger.error("MQTT读取超时, requestId: %s,限定时间:%.1s , 已运行:%.1fs"%(request_id,service_timeout, time() - start_time)) + raise ServiceException(ExceptionType.TASK_EXCUTE_TIMEOUT.value[0], + ExceptionType.TASK_EXCUTE_TIMEOUT.value[1]) + client.loop_stop() # 停止循环 + client.disconnect() # 断开连接 + + def run(self): + request_id, mqtt_list, progress = self.__request_id, self.__mqtt_list, init_progess + analyse_type, fb_queue = self.__analyse_type, self.__fb_queue + #service_timeout = int(self.__config["service"]["timeout"]) + 120 + + try: + logger.info("开始MQTT读取线程!requestId:{}", request_id) + mqtt_init_num = 0 + self.mqtt_connect() + + except Exception: + logger.error("MQTT线程异常:{}, requestId:{}", format_exc(), request_id) + finally: + mqtt_list = [] + logger.info("MQTT线程停止完成!requestId:{}", request_id) + + +def start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context): + mqtt_thread = PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context) + mqtt_thread.setDaemon(True) + mqtt_thread.start() + return mqtt_thread +def start_PullVideo(mqtt_list): + for i in range(1000): + sleep(1) + if len(mqtt_list)>=10: + print( mqtt_list[4]) + print(i,len(mqtt_list)) +if __name__=="__main__": + #context = {'service':{'timeout':3600},'mqtt':{ + # 'broker':"101.133.163.127",'port':1883,'topic':"test/topic","length":10} + # } + context = { + 'base_dir':'/home/th/WJ/test/tuoheng_algN', + 'env':'test' + + } + analyse_type = '1' + request_id = '123456789' + event_queue, pull_queue, mqtt_list, image_queue, push_queue, push_ex_queue = Queue(), Queue(10), [], Queue(), Queue(), Queue() + fb_queue = Queue() + mqtt_thread = start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context) + + + start_PullVideo(mqtt_list) + print('---line117--') + + + + #mqtt_thread.join() \ No newline at end of file diff --git a/concurrency/PullVideoStreamProcess.py b/concurrency/PullVideoStreamProcess.py index 54fe030..933438a 100644 --- a/concurrency/PullVideoStreamProcess.py +++ b/concurrency/PullVideoStreamProcess.py @@ -35,15 +35,15 @@ class PullVideoStreamProcess(Process): put_queue(self._command_queue, result, timeout=2, is_ex=True) @staticmethod - def start_File_upload(fb_queue, context, msg, image_queue, analyse_type,mqtt_list): - image_thread = ImageFileUpload(fb_queue, context, msg, image_queue, analyse_type,mqtt_list) + def start_File_upload(fb_queue, context, msg, image_queue, analyse_type): + image_thread = ImageFileUpload(fb_queue, context, msg, image_queue, analyse_type) image_thread.setDaemon(True) image_thread.start() return image_thread @staticmethod - def start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context): - mqtt_thread = PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context) + def start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context,business): + mqtt_thread = PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context,business) mqtt_thread.setDaemon(True) mqtt_thread.start() return mqtt_thread @@ -81,13 +81,14 @@ class OnlinePullVideoStreamProcess(PullVideoStreamProcess): # 初始化日志 init_log(base_dir, env) logger.info("开启启动实时视频拉流进程, requestId:{},pid:{},ppid:{}", request_id,os.getpid(),os.getppid()) - - #开启mqtt - if service["mqtt_flag"]==1: - mqtt_thread = self.start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context) + + # 开启mqtt + if service['mqtt']["flag"] == 1: + business = service['mqtt']["business"] + mqtt_thread = self.start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context, business) # 开启图片上传线程 - image_thread = self.start_File_upload(fb_queue, context, msg, image_queue, analyse_type,mqtt_list) + image_thread = self.start_File_upload(fb_queue, context, msg, image_queue, analyse_type) cv2_init_num, init_pull_num, concurrent_frame = 0, 1, 1 start_time, pull_stream_start_time, read_start_time, full_timeout = time(), None, None, None while True: @@ -129,7 +130,7 @@ class OnlinePullVideoStreamProcess(PullVideoStreamProcess): frame, pull_p, width, height = pull_read_video_stream(pull_p, pull_url, width, height, width_height_3, w_2, h_2, request_id) if pull_queue.full(): - logger.info("pull拉流队列满了:{}, requestId: {}", os.getppid(), request_id) + #logger.info("pull拉流队列满了:{}, requestId: {}", os.getppid(), request_id) if full_timeout is None: full_timeout = time() if time() - full_timeout > 180: @@ -171,7 +172,7 @@ class OnlinePullVideoStreamProcess(PullVideoStreamProcess): sleep(1) continue init_pull_num, read_start_time = 1, None - frame_list.append(frame) + frame_list.append([frame, mqtt_list]) frame_index_list.append(concurrent_frame) if len(frame_list) >= frame_num: put_queue(pull_queue, (4, (frame_list, frame_index_list, all_frames)), timeout=1, is_ex=True) @@ -222,10 +223,11 @@ class OfflinePullVideoStreamProcess(PullVideoStreamProcess): def run(self): msg, context, frame_num, analyse_type = self._msg, self._context, self._frame_num, self._analyse_type - request_id, base_dir, env, pull_url = msg["request_id"], context['base_dir'], context['env'], msg["pull_url"] + request_id, base_dir, env, pull_url, service = msg["request_id"], context['base_dir'], context['env'], msg["pull_url"], context["service"] ex, service_timeout, full_timeout = None, int(context["service"]["timeout"]) + 120, None - command_queue, pull_queue, image_queue, fb_queue = self._command_queue, self._pull_queue, self._image_queue, \ - self._fb_queue + + command_queue, pull_queue, image_queue, fb_queue, mqtt_list = self._command_queue, self._pull_queue, self._image_queue, \ + self._fb_queue, self._mqtt_list image_thread, pull_p = None, None width, height, width_height_3, all_frames, w_2, h_2 = None, None, None, 0, None, None frame_list, frame_index_list = [], [] @@ -235,8 +237,12 @@ class OfflinePullVideoStreamProcess(PullVideoStreamProcess): init_log(base_dir, env) logger.info("开启离线视频拉流进程, requestId:{}", request_id) + #开启mqtt + if service['mqtt']["flag"]==1: + business = service['mqtt']["business"] + mqtt_thread = self.start_PullMqtt(fb_queue, mqtt_list, request_id, analyse_type, context, business) # 开启图片上传线程 - image_thread = self.start_File_upload(fb_queue, context, msg, image_queue, analyse_type,[]) + image_thread = self.start_File_upload(fb_queue, context, msg, image_queue, analyse_type) # 初始化拉流工具类 cv2_init_num, concurrent_frame = 0, 1 @@ -269,7 +275,7 @@ class OfflinePullVideoStreamProcess(PullVideoStreamProcess): width, height, width_height_3, all_frames, w_2, h_2 = build_video_info(pull_url, request_id) continue if pull_queue.full(): - logger.info("pull拉流队列满了:{}, requestId: {}", os.getppid(), request_id) + #logger.info("pull拉流队列满了:{}, requestId: {}", os.getppid(), request_id) if full_timeout is None: full_timeout = time() if time() - full_timeout > 180: @@ -306,7 +312,7 @@ class OfflinePullVideoStreamProcess(PullVideoStreamProcess): ExceptionType.READSTREAM_TIMEOUT_EXCEPTION.value[1]) logger.info("离线拉流线程结束, requestId: {}", request_id) break - frame_list.append(frame) + frame_list.append([frame,mqtt_list]) frame_index_list.append(concurrent_frame) if len(frame_list) >= frame_num: put_queue(pull_queue, (4, (frame_list, frame_index_list, all_frames)), timeout=1, is_ex=True) diff --git a/concurrency/PushVideoStreamProcess.py b/concurrency/PushVideoStreamProcess.py index 724f189..1d79c7d 100644 --- a/concurrency/PushVideoStreamProcess.py +++ b/concurrency/PushVideoStreamProcess.py @@ -23,7 +23,7 @@ from util.Cv2Utils import video_conjuncing, write_or_video, write_ai_video, push from util.ImageUtils import url2Array, add_water_pic from util.LogUtils import init_log -from util.PlotsUtils import draw_painting_joint, filterBox, xywh2xyxy2, xy2xyxy, draw_name_joint, plot_one_box_auto, draw_name_ocr,draw_name_crowd +from util.PlotsUtils import draw_painting_joint, filterBox, xywh2xyxy2, xy2xyxy, draw_name_joint, plot_one_box_auto, draw_name_ocr,draw_name_crowd,draw_transparent_red_polygon from util.QueUtil import get_no_block_queue, put_queue, clear_queue @@ -36,11 +36,10 @@ class PushStreamProcess(Process): # 传参 self._msg, self._push_queue, self._image_queue, self._push_ex_queue, self._hb_queue, self._context = args self._algStatus = False # 默认关闭 - self._algSwitch = self._context['service']['algSwitch'] + self._algSwitch = self._context['service']['algSwitch'] - - #0521: - default_enabled = str(self._msg.get("defaultEnabled", "True")).lower() == "true" + # 0521: + default_enabled = str(self._msg.get("defaultEnabled", "True")).lower() == "true" if default_enabled: print("执行默认程序(defaultEnabled=True)") self._algSwitch = True @@ -131,15 +130,26 @@ class OnPushStreamProcess(PushStreamProcess): if push_r is not None: if push_r[0] == 1: frame_list, frame_index_list, all_frames, draw_config, push_objs = push_r[1] - for i, frame in enumerate(frame_list): + # 处理每1帧 + for i, [frame,mqtt_list] in enumerate(frame_list): + # mqtt传参 + border = None + mqttPares = None + if len(mqtt_list) >= 1: + mqttPares = mqtt_list[0] + if mqttPares[0] == 1: + border = mqttPares[1] pix_dis = int((frame.shape[0]//10)*1.2) # 复制帧用来画图 copy_frame = frame.copy() + if border is not None: + copy_frame = draw_transparent_red_polygon(copy_frame, np.array(border, np.int32),alpha=0.25) det_xywh, thread_p = {}, [] - det_xywh2 = {} + det_xywh2 = {'mqttPares':mqttPares} # 所有问题的矩阵集合 qs_np = None qs_reurn = [] + bp_np = None for det in push_objs[i]: code, det_result = det # 每个单独模型处理 @@ -168,18 +178,21 @@ class OnPushStreamProcess(PushStreamProcess): else: try: # 应对NaN情况 box, score, cls = xywh2xyxy2(qs) + if cls not in allowedList or score < frame_score: + continue + if ModelType.CITY_FIREAREA_MODEL.value[1] == str(code): + # 借score作为points点集 + box.append(qs[-1]) except: continue - if cls not in allowedList or score < frame_score: - continue label_array, color = label_arrays[cls], rainbows[cls] if ModelType.CHANNEL2_MODEL.value[1] == str(code) and cls == 2: rr = t.submit(draw_name_joint, box, copy_frame, draw_config[code]["label_dict"], score, color, font_config, qs[6]) else: - rr = t.submit(draw_painting_joint, box, copy_frame, label_array, - score, color, font_config) + rr = t.submit(draw_painting_joint, box, copy_frame, label_array, score, color, font_config, border=border) + thread_p.append(rr) if det_xywh.get(code) is None: @@ -189,17 +202,24 @@ class OnPushStreamProcess(PushStreamProcess): if cd is None: det_xywh[code][cls] = [[cls, box, score, label_array, color]] else: - det_xywh[code][cls].append([cls, box, score, label_array, color]) + det_xywh[code][cls].append([cls, box, score, label_array, color]) if qs_np is None: - qs_np = np.array([box[0][0], box[0][1], box[1][0], box[1][1], - box[2][0], box[2][1], box[3][0], box[3][1], + qs_np = np.array([box[0][0], box[0][1], box[1][0], box[1][1], + box[2][0], box[2][1], box[3][0], box[3][1], score, cls, code],dtype=np.float32) else: - result_li = np.array([box[0][0], box[0][1], box[1][0], box[1][1], + result_li = np.array([box[0][0], box[0][1], box[1][0], box[1][1], box[2][0], box[2][1], box[3][0], box[3][1], score, cls, code],dtype=np.float32) qs_np = np.row_stack((qs_np, result_li)) + if ModelType.CITY_FIREAREA_MODEL.value[1] == str(code): + if bp_np is None: + bp_np = np.array([box[0][0], box[0][1], box[-1]], dtype=object) + else: + bp_li = np.array([box[0][0], box[0][1], box[-1]], dtype=object) + bp_np = np.row_stack((bp_np, bp_li)) + if logo: frame = add_water_pic(frame, logo, request_id) copy_frame = add_water_pic(copy_frame, logo, request_id) @@ -230,7 +250,7 @@ class OnPushStreamProcess(PushStreamProcess): if picture_similarity: qs_np_tmp = qs_np_id.copy() b = np.zeros(qs_np.shape[0]) - qs_reurn = np.column_stack((qs_np,b)) + qs_reurn = np.column_stack((qs_np,b)) else: qs_reurn = filterBox(qs_np, qs_np_tmp, pix_dis) if picture_similarity: @@ -256,8 +276,14 @@ class OnPushStreamProcess(PushStreamProcess): score = q[8] rainbows, label_arrays = draw_config[code]["rainbows"], draw_config[code]["label_arrays"] label_array, color = label_arrays[cls], rainbows[cls] - box = [(int(q[0]), int(q[1])), (int(q[2]), int(q[3])), - (int(q[4]), int(q[5])), (int(q[6]), int(q[7]))] + box = [(int(q[0]), int(q[1])), (int(q[2]), int(q[3])), + (int(q[4]), int(q[5])), (int(q[6]), int(q[7]))] + if bp_np is not None: + if len(bp_np.shape)==1: + bp_np = bp_np[np.newaxis, ...] + for bp in bp_np: + if np.array_equal(bp[:2], np.array([int(q[0]), int(q[1])])): + box.append(bp[-1]) is_new = False if q[11] == 1: is_new = True @@ -281,7 +307,7 @@ class OnPushStreamProcess(PushStreamProcess): if push_r[0] == 2: logger.info("拉流进程收到控制命令为:{}, requestId: {}",push_r[1] ,request_id) if 'algStart' == push_r[1]: self._algStatus = True;logger.info("算法识别开启, requestId: {}", request_id) - if 'algStop' == push_r[1]: self._algStatus = False;logger.info("算法识别关闭, requestId: {}", request_id) + if 'algStop' == push_r[1]: self._algStatus = False;logger.info("算法识别关闭, requestId: {}", request_id) if 'stop' == push_r[1]: logger.info("停止推流进程, requestId: {}", request_id) break @@ -368,23 +394,33 @@ class OffPushStreamProcess(PushStreamProcess): # [(2, 操作指令)] 指令操作 if push_r[0] == 1: frame_list, frame_index_list, all_frames, draw_config, push_objs = push_r[1] + # 处理每一帧图片 - for i, frame in enumerate(frame_list): + for i, [frame,mqtt_list] in enumerate(frame_list): + # mqtt传参 + border = None + mqttPares = None + if len(mqtt_list) >= 1: + mqttPares = mqtt_list[0] + if mqttPares[0] == 1: + border = mqttPares[1] pix_dis = int((frame.shape[0]//10)*1.2) if frame_index_list[i] % 300 == 0 and frame_index_list[i] <= all_frames: task_process = "%.2f" % (float(frame_index_list[i]) / float(all_frames)) put_queue(hb_queue, {"hb_value": task_process}, timeout=2) # 复制帧用来画图 copy_frame = frame.copy() + if border is not None: + copy_frame = draw_transparent_red_polygon(copy_frame, np.array(border, np.int32),alpha=0.25) # 所有问题记录字典 det_xywh, thread_p = {}, [] - det_xywh2 = {} + det_xywh2 = {'mqttPares':mqttPares} # 所有问题的矩阵集合 qs_np = None qs_reurn = [] + bp_np = None for det in push_objs[i]: code, det_result = det - # 每个单独模型处理 # 模型编号、100帧的所有问题, 检测目标、颜色、文字图片 if len(det_result) > 0: @@ -414,14 +450,16 @@ class OffPushStreamProcess(PushStreamProcess): box, score, cls = xywh2xyxy2(qs) if cls not in allowedList or score < frame_score: continue + if ModelType.CITY_FIREAREA_MODEL.value[1] == str(code): + box.append(qs[-1]) + label_array, color = label_arrays[cls], rainbows[cls] if ModelType.CHANNEL2_MODEL.value[1] == str(code) and cls == 2: rr = t.submit(draw_name_joint, box, copy_frame, draw_config[code]["label_dict"], score, color, font_config, qs[6]) else: - rr = t.submit(draw_painting_joint, box, copy_frame, label_array, score, color, font_config) + rr = t.submit(draw_painting_joint, box, copy_frame, label_array, score, color, font_config, border=border) thread_p.append(rr) - if det_xywh.get(code) is None: det_xywh[code] = {} cd = det_xywh[code].get(cls) @@ -440,6 +478,13 @@ class OffPushStreamProcess(PushStreamProcess): score, cls, code],dtype=np.float32) qs_np = np.row_stack((qs_np, result_li)) + if ModelType.CITY_FIREAREA_MODEL.value[1]== str(code): + if bp_np is None: + bp_np = np.array([box[0][0], box[0][1],box[-1]],dtype=object) + else: + bp_li = np.array([box[0][0], box[0][1],box[-1]],dtype=object) + bp_np = np.row_stack((bp_np, bp_li)) + if logo: frame = add_water_pic(frame, logo, request_id) copy_frame = add_water_pic(copy_frame, logo, request_id) @@ -467,7 +512,7 @@ class OffPushStreamProcess(PushStreamProcess): if picture_similarity: qs_np_tmp = qs_np_id.copy() b = np.zeros(qs_np.shape[0]) - qs_reurn = np.column_stack((qs_np,b)) + qs_reurn = np.column_stack((qs_np,b)) else: qs_reurn = filterBox(qs_np, qs_np_tmp, pix_dis) if picture_similarity: @@ -494,8 +539,14 @@ class OffPushStreamProcess(PushStreamProcess): score = q[8] rainbows, label_arrays = draw_config[code]["rainbows"], draw_config[code]["label_arrays"] label_array, color = label_arrays[cls], rainbows[cls] - box = [(int(q[0]), int(q[1])), (int(q[2]), int(q[3])), - (int(q[4]), int(q[5])), (int(q[6]), int(q[7]))] + box = [(int(q[0]), int(q[1])), (int(q[2]), int(q[3])), + (int(q[4]), int(q[5])), (int(q[6]), int(q[7]))] + if bp_np is not None: + if len(bp_np.shape)==1: + bp_np = bp_np[np.newaxis, ...] + for bp in bp_np: + if np.array_equal(bp[:2], np.array([int(q[0]), int(q[1])])): + box.append(bp[-1]) is_new = False if q[11] == 1: is_new = True @@ -509,7 +560,7 @@ class OffPushStreamProcess(PushStreamProcess): else: det_xywh2[code][cls].append( [cls, box, score, label_array, color, is_new]) - if len(det_xywh2) > 0: + if len(det_xywh2) > 1: put_queue(image_queue, (1, [det_xywh2, frame, frame_index_list[i], all_frames, draw_config["font_config"]])) push_p = push_stream_result.result(timeout=60) ai_video_file = write_ai_video_result.result(timeout=60) @@ -517,7 +568,7 @@ class OffPushStreamProcess(PushStreamProcess): if push_r[0] == 2: logger.info("拉流进程收到控制命令为:{}, requestId: {}",push_r[1] ,request_id) if 'algStart' == push_r[1]: self._algStatus = True;logger.info("算法识别开启, requestId: {}", request_id) - if 'algStop' == push_r[1]: self._algStatus = False;logger.info("算法识别关闭, requestId: {}", request_id) + if 'algStop' == push_r[1]: self._algStatus = False;logger.info("算法识别关闭, requestId: {}", request_id) if 'stop' == push_r[1]: logger.info("停止推流进程, requestId: {}", request_id) break diff --git a/config/mqtt/dsp_test_mqtt.yml b/config/mqtt/dsp_test_mqtt.yml index a8e2310..6ac0150 100644 --- a/config/mqtt/dsp_test_mqtt.yml +++ b/config/mqtt/dsp_test_mqtt.yml @@ -1,10 +1,21 @@ mqtt_flag: true -broker : "58.213.148.44" -port : 1883 -username: "admin" -password: "admin##123" -#topic: "/topic/v1/airportFly/%s/aiDroneData" -topic: "/topic/v1/airportDrone/THJSQ03B2309TPCTD5QV/realTime/data" -# 存储多少条消息到list里 -length: 10 +# 业务0为经纬度定位,业务1为入侵算法开关 +business: 1 +# 经纬度定位 +location: + broker : "58.213.148.44" + port : 1883 + username: "admin" + password: "admin##123" + #topic: "/topic/v1/airportFly/%s/aiDroneData" + topic: "/topic/v1/airportDrone/THJSQ03B2309TPCTD5QV/realTime/data" +# 入侵 +invade: + broker : "192.168.11.8" + port : 2883 + #topic: "/topic/v1/airportFly/%s/aiDroneData" + topic: "test000/topic" +# 存储多少条消息到list里 + +length: 30 \ No newline at end of file diff --git a/config/service/dsp_test_service.yml b/config/service/dsp_test_service.yml index ab0def4..94293b8 100644 --- a/config/service/dsp_test_service.yml +++ b/config/service/dsp_test_service.yml @@ -33,7 +33,8 @@ service: #storage source,0--aliyun,1--minio storage_source: 0 #是否启用mqtt,0--不用,1--启用 - mqtt_flag: 0 + mqtt: + flag: 0 + business: 1 #是否启用alg控制功能 - algSwitch: False - + algSwitch: False \ No newline at end of file diff --git a/enums/ModelTypeEnum.py b/enums/ModelTypeEnum.py index d2adeba..392145c 100644 --- a/enums/ModelTypeEnum.py +++ b/enums/ModelTypeEnum.py @@ -14,6 +14,7 @@ from utilsK.drownUtils import mixDrowing_water_postprocess from utilsK.noParkingUtils import mixNoParking_road_postprocess from utilsK.illParkingUtils import illParking_postprocess from utilsK.pannelpostUtils import pannel_post_process +from utilsK.securitypostUtils import security_post_process from stdc import stdcModel from yolov5 import yolov5Model from p2pNet import p2NnetModel @@ -63,7 +64,7 @@ class ModelType(Enum): "classes": 5, "rainbows": COLOR }, - + 'fiterList':[2], 'Detweights': "../weights/trt/AIlib2/river/yolov5_%s_fp16.engine" % gpuName, 'Segweights': '../weights/trt/AIlib2/river/stdc_360X640_%s_fp16.engine' % gpuName }) @@ -99,10 +100,8 @@ class ModelType(Enum): 'weight':"../weights/trt/AIlib2/forest2/yolov5_%s_fp16.engine"%(gpuName),###检测模型路径 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } }, + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False}, } - - ], 'postFile': { @@ -112,11 +111,10 @@ class ModelType(Enum): "classes": 5, "rainbows": COLOR }, - 'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6,7,8,9] ],###控制哪些检测类别显示、输出 + "score_byClass": {0: 0.25, 1: 0.3, 2: 0.3, 3: 0.3}, + 'fiterList': [5], 'segRegionCnt':2,###分割模型结果需要保留的等值线数目 "pixScale": 1.2, - - }) @@ -162,7 +160,8 @@ class ModelType(Enum): "classes": 10, "rainbows": COLOR }, - 'allowedList':[0,1,2,3,4,5,6,7,8,9,10,11,12,16,17,18,19,20,21,22], + 'score_byClass':{11:0.75,12:0.75}, + 'fiterList': [13,14,15,16,17,18,19,20,21,22], 'Detweights': "../weights/trt/AIlib2/highWay2/yolov5_%s_fp16.engine" % gpuName, 'Segweights': '../weights/trt/AIlib2/highWay2/stdc_360X640_%s_fp16.engine' % gpuName }) @@ -231,7 +230,7 @@ class ModelType(Enum): "classes": 5, "rainbows": COLOR }, - 'Segweights': None + 'Segweights': None, }) ANGLERSWIMMER_MODEL = ("9", "009", "钓鱼游泳模型", 'AnglerSwimmer', lambda device, gpuName: { @@ -345,7 +344,8 @@ class ModelType(Enum): 'function': riverDetSegMixProcess, 'pars': { 'slopeIndex': [1, 3, 4, 7], - 'riverIou': 0.1 + 'riverIou': 0.1, + 'scale': 0.25 } } }, @@ -377,10 +377,11 @@ class ModelType(Enum): 'weight':'../weights/trt/AIlib2/cityMangement3/yolov5_%s_fp16.engine'%(gpuName), 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3,4,5,6,7],'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':True, "score_byClass":{"0":0.8,"1":0.4,"2":0.5,"3":0.5 } } + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':True} }, { - 'weight':'../weights/pth/AIlib2/cityMangement3/dmpr.pth', + 'weight':'../weights/trt/AIlib2/cityMangement3/dmpr_3090.engine', + #'weight':'../weights/pth/AIlib2/cityMangement3/dmpr.pth', 'par':{ 'depth_factor':32,'NUM_FEATURE_MAP_CHANNEL':6,'dmpr_thresh':0.1, 'dmprimg_size':640, 'name':'dmpr' @@ -403,7 +404,7 @@ class ModelType(Enum): "classes": 8, "rainbows": COLOR }, - 'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6,7,8,9] ],###控制哪些检测类别显示、输出 + "score_byClass":{0:0.8, 1:0.4, 2:0.5, 3:0.5}, 'segRegionCnt':2,###分割模型结果需要保留的等值线数目 "pixScale": 1.2, }) @@ -568,10 +569,10 @@ class ModelType(Enum): 'weight':'../weights/trt/AIlib2/channel2/yolov5_%s_fp16.engine'%(gpuName), 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.1,'iou_thres':0.45,'allowedList':list(range(20)),'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{"0":0.7,"1":0.7,"2":0.8,"3":0.6} } + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.1,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False} }, { - 'weight' : '../weights/pth/AIlib2/ocr2/crnn_ch.pth', + 'weight' : '../weights/trt/AIlib2/ocr2/crnn_ch_%s_fp16_192X32.engine'%(gpuName), 'name':'ocr', 'model':ocrModel, 'par':{ @@ -587,7 +588,6 @@ class ModelType(Enum): }, } ], - 'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6]], 'segPar': None, 'postFile': { "name": "post_process", @@ -597,6 +597,8 @@ class ModelType(Enum): "rainbows": COLOR }, 'Segweights': None, + "score_byClass": {0: 0.7, 1: 0.7, 2: 0.8, 3: 0.6} + }) RIVERT_MODEL = ("25", "025", "河道检测模型(T)", 'riverT', lambda device, gpuName: { @@ -642,7 +644,7 @@ class ModelType(Enum): 'weight':"../weights/trt/AIlib2/forestCrowd/yolov5_%s_fp16.engine"%(gpuName),###检测模型路径 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':[0,1,2,3],'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False, "score_byClass":{ "0":0.25,"1":0.25,"2":0.6,"3":0.6,'4':0.6 ,'5':0.6 } }, + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':False,'trtFlag_seg':False}, } @@ -655,7 +657,7 @@ class ModelType(Enum): "classes": 5, "rainbows": COLOR }, - 'detModelpara':[{"id":str(x),"config":{"k1":"v1","k2":"v2"}} for x in [0,1,2,3,4,5,6,7,8,9] ],###控制哪些检测类别显示、输出 + "score_byClass":{0:0.25,1:0.25,2:0.6,3:0.6,4:0.6 ,5:0.6}, 'segRegionCnt':2,###分割模型结果需要保留的等值线数目 "pixScale": 1.2, @@ -703,6 +705,7 @@ class ModelType(Enum): "classes": 10, "rainbows": COLOR }, + 'fiterltList': [11,12,13,14,15,16,17], 'Detweights': "../weights/trt/AIlib2/highWay2T/yolov5_%s_fp16.engine" % gpuName, 'Segweights': '../weights/trt/AIlib2/highWay2T/stdc_360X640_%s_fp16.engine' % gpuName }) @@ -716,7 +719,7 @@ class ModelType(Enum): 'weight':"../weights/trt/AIlib2/smartSite/yolov5_%s_fp16.engine"%(gpuName),###检测模型路径 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':list(range(20)),'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } }, + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':False}, } @@ -724,6 +727,7 @@ class ModelType(Enum): 'postFile': { "rainbows": COLOR }, + "score_byClass": {0: 0.25, 1: 0.3, 2: 0.3, 3: 0.3} }) @@ -736,7 +740,7 @@ class ModelType(Enum): 'weight':"../weights/trt/AIlib2/rubbish/yolov5_%s_fp16.engine"%(gpuName),###检测模型路径 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':list(range(20)),'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } }, + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':False}, } @@ -744,6 +748,7 @@ class ModelType(Enum): 'postFile': { "rainbows": COLOR }, + "score_byClass": {0: 0.25, 1: 0.3, 2: 0.3, 3: 0.3} }) @@ -756,7 +761,7 @@ class ModelType(Enum): 'weight':"../weights/trt/AIlib2/firework/yolov5_%s_fp16.engine"%(gpuName),###检测模型路径 'name':'yolov5', 'model':yolov5Model, - 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'allowedList':list(range(20)),'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':False, "score_byClass":{"0":0.25,"1":0.3,"2":0.3,"3":0.3 } }, + 'par':{ 'half':True,'device':'cuda:0' ,'conf_thres':0.25,'iou_thres':0.45,'segRegionCnt':1, 'trtFlag_det':True,'trtFlag_seg':False }, } @@ -764,7 +769,6 @@ class ModelType(Enum): 'postFile': { "rainbows": COLOR }, - }) TRAFFIC_SPILL_MODEL = ("50", "501", "高速公路抛洒物模型", 'highWaySpill', lambda device, gpuName: { @@ -806,7 +810,7 @@ class ModelType(Enum): "classes": 2, "rainbows": COLOR }, - 'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0]], + 'fiterList': [1], ###控制哪些检测类别显示、输出 'Detweights': "../weights/trt/AIlib2/highWaySpill/yolov5_%s_fp16.engine" % gpuName, 'Segweights': '../weights/trt/AIlib2/highWaySpill/stdc_360X640_%s_fp16.engine' % gpuName @@ -851,7 +855,7 @@ class ModelType(Enum): "classes": 4, "rainbows": COLOR }, - 'detModelpara': [{"id": str(x), "config": {"k1": "v1", "k2": "v2"}} for x in [0]], + 'fiterList':[1,2,3], ###控制哪些检测类别显示、输出 'Detweights': "../weights/trt/AIlib2/highWayCthc/yolov5_%s_fp16.engine" % gpuName, 'Segweights': '../weights/trt/AIlib2/highWayCthc/stdc_360X640_%s_fp16.engine' % gpuName @@ -867,14 +871,15 @@ class ModelType(Enum): 'name': 'yolov5', 'model': yolov5Model, 'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45, - 'allowedList': [0,1,2], 'segRegionCnt': 1, 'trtFlag_det': True, - 'trtFlag_seg': False, "score_byClass": {"0": 0.25, "1": 0.3, "2": 0.3, "3": 0.3}}, + 'segRegionCnt': 1, 'trtFlag_det': True, + 'trtFlag_seg': False}, } ], 'postFile': { "rainbows": COLOR }, + 'fiterList':[0] }) @@ -884,13 +889,13 @@ class ModelType(Enum): 'rainbows': COLOR, 'models': [ { - 'trtFlag_det': False, 'weight': '../weights/pth/AIlib2/carplate/plate_yolov5s_v3.jit', 'name': 'yolov5', 'model': yolov5Model, 'par': { + 'trtFlag_det': False, 'device': 'cuda:0', - 'half': False, + 'half': True, 'conf_thres': 0.4, 'iou_thres': 0.45, 'nc': 1, @@ -898,11 +903,11 @@ class ModelType(Enum): }, }, { - 'trtFlag_ocr': False, - 'weight': '../weights/pth/AIlib2/ocr2/crnn_ch.pth', + 'weight' : '../weights/trt/AIlib2/ocr2/crnn_ch_%s_fp16_192X32.engine'%(gpuName), 'name': 'ocr', 'model': ocrModel, 'par': { + 'trtFlag_ocr': True, 'char_file': '../AIlib2/conf/ocr2/benchmark.txt', 'mode': 'ch', 'nc': 3, @@ -926,10 +931,8 @@ class ModelType(Enum): 'name': 'yolov5', 'model': yolov5Model, 'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.50, 'iou_thres': 0.45, - 'allowedList': list(range(20)), 'segRegionCnt': 1, 'trtFlag_det': True, - 'trtFlag_seg': False, "score_byClass": {"0": 0.50, "1": 0.3, "2": 0.3, "3": 0.3}}, + 'segRegionCnt': 1, 'trtFlag_det': True,'trtFlag_seg': False}, } - ], 'postFile': { "rainbows": COLOR @@ -947,8 +950,7 @@ class ModelType(Enum): 'name': 'yolov5', 'model': yolov5Model, 'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.50, 'iou_thres': 0.45, - 'allowedList': list(range(20)), 'segRegionCnt': 1, 'trtFlag_det': True, - 'trtFlag_seg': False, "score_byClass": {"0": 0.50, "1": 0.3, "2": 0.3, "3": 0.3}}, + 'segRegionCnt': 1, 'trtFlag_det': True, 'trtFlag_seg': False}, } ], @@ -995,8 +997,7 @@ class ModelType(Enum): 'name': 'yolov5', 'model': yolov5Model, 'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.50, 'iou_thres': 0.45, - 'allowedList': list(range(20)), 'segRegionCnt': 1, 'trtFlag_det': True, - 'trtFlag_seg': False, "score_byClass": {"0": 0.50, "1": 0.3, "2": 0.3, "3": 0.3}}, + 'segRegionCnt': 1, 'trtFlag_det': True, 'trtFlag_seg': False}, } ], @@ -1016,8 +1017,7 @@ class ModelType(Enum): 'name': 'yolov5', 'model': yolov5Model, 'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45, - 'allowedList': [0,1,2], 'segRegionCnt': 1, 'trtFlag_det': True, - 'trtFlag_seg': False, "score_byClass": {"0": 0.25, "1": 0.3, "2": 0.3, "3": 0.3}}, + 'segRegionCnt': 1, 'trtFlag_det': True, 'trtFlag_seg': False}, }, { 'trtFlag_det': False, @@ -1042,6 +1042,54 @@ class ModelType(Enum): }], }) + CITY_FIREAREA_MODEL = ("30", "307", "火焰面积模型", 'FireArea', lambda device, gpuName: { + 'device': device, + 'gpu_name': gpuName, + 'labelnames': ["火焰"], + 'seg_nclass': 2, # 分割模型类别数目,默认2类 + 'segRegionCnt': 0, + 'trtFlag_det': True, + 'trtFlag_seg': False, + 'Detweights': "../weights/trt/AIlib2/smogfire/yolov5_%s_fp16.engine" % gpuName, # 0:fire 1:smoke + 'Samweights': "../weights/pth/AIlib2/firearea/sam_vit_b_01ec64.pth", #分割模型 + 'ksize':(7,7), + 'sam_type':'vit_b', + 'slopeIndex': [], + 'segPar': None, + 'postFile': { + "name": "post_process", + "conf_thres": 0.25, + "iou_thres": 0.45, + "classes": 5, + "rainbows": COLOR + }, + 'Segweights': None, + 'fiterList':[1], + "score_byClass": {0: 0.1} + + }) + + CITY_SECURITY_MODEL = ("30", "308", "安防模型", 'SECURITY', lambda device, gpuName: { + 'labelnames': ["带安全帽","安全帽","攀爬","斗殴","未戴安全帽"], + 'postProcess': {'function': security_post_process, 'pars': {'objs': [0,1],'iou':0.25,'unhelmet':4}}, + 'models': + [ + { + 'weight': "../weights/trt/AIlib2/security/yolov5_%s_fp16.engine" % (gpuName), ###检测模型路径 + 'name': 'yolov5', + 'model': yolov5Model, + 'par': {'half': True, 'device': 'cuda:0', 'conf_thres': 0.25, 'iou_thres': 0.45, + 'segRegionCnt': 1, 'trtFlag_det': True, 'trtFlag_seg': False}, + } + + ], + 'postFile': { + "rainbows": COLOR + }, + 'fiterList': [0,1], + "score_byClass": {"0": 0.50} + }) + @staticmethod def checkCode(code): for model in ModelType: diff --git a/util/ModelUtils.py b/util/ModelUtils.py index 4dd2f7e..d3e9f7b 100644 --- a/util/ModelUtils.py +++ b/util/ModelUtils.py @@ -27,6 +27,7 @@ import torch import tensorrt as trt from utilsK.jkmUtils import pre_process, post_process, get_return_data from DMPR import DMPRModel +from segment_anything import SamPredictor, sam_model_registry FONT_PATH = "../AIlib2/conf/platech.ttf" @@ -36,6 +37,7 @@ class OneModel: def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None): try: + start = time.time() logger.info("########################加载{}########################, requestId:{}", modeType.value[2], requestId) par = modeType.value[4](str(device), gpu_name) @@ -68,10 +70,11 @@ class OneModel: 'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"), 'iou_thres': postFile["iou_thres"], # 对高速模型进行过滤 - 'allowedList': par['allowedList'] if modeType.value[0] == '3' else [], 'segRegionCnt': par['segRegionCnt'], 'trtFlag_det': par['trtFlag_det'], - 'trtFlag_seg': par['trtFlag_seg'] + 'trtFlag_seg': par['trtFlag_seg'], + 'score_byClass':par['score_byClass'] if 'score_byClass' in par.keys() else None, + 'fiterList': par['fiterList'] if 'fiterList' in par.keys() else [] } model_param = { "model": model, @@ -86,6 +89,7 @@ class OneModel: logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) + logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId) # 纯分类模型 class cityManagementModel: __slots__ = "model_conf" @@ -103,6 +107,8 @@ class cityManagementModel: model_param = { "modelList": modelList, "postProcess": postProcess, + "score_byClass":par['score_byClass'] if 'score_byClass' in par.keys() else None, + "fiterList":par['fiterList'] if 'fiterList' in par.keys() else [], } self.model_conf = (modeType, model_param, allowedList, names, rainbows) except Exception: @@ -111,15 +117,14 @@ class cityManagementModel: ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) def detSeg_demo2(args): model_conf, frame, request_id = args - modelList, postProcess = model_conf[1]['modelList'], model_conf[1]['postProcess'] + modelList, postProcess,score_byClass,fiterList = ( + model_conf[1]['modelList'], model_conf[1]['postProcess'],model_conf[1]['score_byClass'], model_conf[1]['fiterList']) try: - result = [[ None, None, AI_process_N([frame], modelList, postProcess)[0] ] ] # 为了让返回值适配统一的接口而写的shi + result = [[ None, None, AI_process_N([frame], modelList, postProcess,score_byClass,fiterList)[0] ] ] # 为了让返回值适配统一的接口而写的shi return result except ServiceException as s: raise s except Exception: - # self.num += 1 - # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame) logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id) raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) @@ -127,11 +132,6 @@ def detSeg_demo2(args): def model_process(args): model_conf, frame, request_id = args model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4] - # modeType, model_param, allowedList, names, rainbows = model_conf - # segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId = args - # model_param['digitFont'] = digitFont - # model_param['label_arraylist'] = label_arraylist - # model_param['font_config'] = font_config try: return AI_process([frame], model_param['model'], model_param['segmodel'], names, model_param['label_arraylist'], rainbows, objectPar=model_param['objectPar'], font=model_param['digitFont'], @@ -164,7 +164,13 @@ class TwoModel: Detweights = par['Detweights'] with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime: model = runtime.deserialize_cuda_engine(f.read()) - segmodel = None + if modeType == ModelType.CITY_FIREAREA_MODEL: + sam = sam_model_registry[par['sam_type']](checkpoint=par['Samweights']) + sam.to(device=device) + segmodel = SamPredictor(sam) + else: + segmodel = None + postFile = par['postFile'] conf_thres = postFile["conf_thres"] iou_thres = postFile["iou_thres"] @@ -178,7 +184,10 @@ class TwoModel: "conf_thres": conf_thres, "iou_thres": iou_thres, "trtFlag_det": par['trtFlag_det'], - "otc": otc + "otc": otc, + "ksize":par['ksize'] if 'ksize' in par.keys() else None, + "score_byClass": par['score_byClass'] if 'score_byClass' in par.keys() else None, + "fiterList": par['fiterList'] if 'fiterList' in par.keys() else [] } self.model_conf = (modeType, model_param, allowedList, names, rainbows) except Exception: @@ -186,16 +195,15 @@ class TwoModel: raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId) - - def forest_process(args): model_conf, frame, request_id = args model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4] try: return AI_process_forest([frame], model_param['model'], model_param['segmodel'], names, model_param['label_arraylist'], rainbows, model_param['half'], model_param['device'], - model_param['conf_thres'], model_param['iou_thres'], [], font=model_param['digitFont'], - trtFlag_det=model_param['trtFlag_det'], SecNms=model_param['otc']) + model_param['conf_thres'], model_param['iou_thres'],font=model_param['digitFont'], + trtFlag_det=model_param['trtFlag_det'], SecNms=model_param['otc'],ksize = model_param['ksize'], + score_byClass=model_param['score_byClass'],fiterList=model_param['fiterList']) except ServiceException as s: raise s except Exception: @@ -204,7 +212,6 @@ def forest_process(args): logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id) raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) - class MultiModel: __slots__ = "model_conf" @@ -223,6 +230,8 @@ class MultiModel: model_param = { "modelList": modelList, "postProcess": postProcess, + "score_byClass": par['score_byClass'] if 'score_byClass' in par.keys() else None, + "fiterList": par['fiterList'] if 'fiterList' in par.keys() else [] } self.model_conf = (modeType, model_param, allowedList, names, rainbows) except Exception: @@ -230,13 +239,13 @@ class MultiModel: raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId) - def channel2_process(args): model_conf, frame, request_id = args - modelList, postProcess = model_conf[1]['modelList'], model_conf[1]['postProcess'] + modelList, postProcess,score_byClass,fiterList = ( + model_conf[1]['modelList'], model_conf[1]['postProcess'],model_conf[1]['score_byClass'], model_conf[1]['fiterList']) try: start = time.time() - result = [[None, None, AI_process_C([frame], modelList, postProcess)[0]]] # 为了让返回值适配统一的接口而写的shi + result = [[None, None, AI_process_C([frame], modelList, postProcess,score_byClass,fiterList)[0]]] # 为了让返回值适配统一的接口而写的shi # print("AI_process_C use time = {}".format(time.time()-start)) return result except ServiceException as s: @@ -245,7 +254,6 @@ def channel2_process(args): logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id) raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) - def get_label_arraylist(*args): width, height, names, rainbows = args # line = int(round(0.002 * (height + width) / 2) + 1) @@ -266,8 +274,6 @@ def get_label_arraylist(*args): 'label_location': 'leftTop'} label_arraylist = get_label_arrays(names, rainbows, fontSize=text_height, fontPath=FONT_PATH) return digitFont, label_arraylist, (line, text_width, text_height, fontScale, tf) - - # 船只模型 class ShipModel: __slots__ = "model_conf" @@ -293,8 +299,6 @@ class ShipModel: raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId) - - def obb_process(args): model_conf, frame, request_id = args model_param = model_conf[1] @@ -309,7 +313,6 @@ def obb_process(args): logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id) raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) - # 车牌分割模型、健康码、行程码分割模型 class IMModel: __slots__ = "model_conf" @@ -333,7 +336,7 @@ class IMModel: new_device = torch.device(par['device']) model = torch.jit.load(par[img_type]['weights']) - logger.info("########################加载 jit 模型成功 成功 ########################, requestId:{}", + logger.info("########################加载 jit 模型成功 成功 ########################, requestId:{}", requestId) self.model_conf = (modeType, allowedList, new_device, model, par, img_type) except Exception: @@ -395,7 +398,6 @@ class CARPLATEModel: raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) - class DENSECROWDCOUNTModel: __slots__ = "model_conf" @@ -763,4 +765,18 @@ MODEL_CONFIG = { None, lambda x: cc_process(x) ), + # 加载火焰面积模型 + ModelType.CITY_FIREAREA_MODEL.value[1]: ( + lambda x, y, r, t, z, h: TwoModel(x, y, r, ModelType.CITY_FIREAREA_MODEL, t, z, h), + ModelType.CITY_FIREAREA_MODEL, + lambda x, y, z: one_label(x, y, z), + lambda x: forest_process(x) + ), + # 加载安防模型 + ModelType.CITY_SECURITY_MODEL.value[1]: ( + lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_SECURITY_MODEL, t, z, h), + ModelType.CITY_SECURITY_MODEL, + lambda x, y, z: one_label(x, y, z), + lambda x: detSeg_demo2(x) + ), } diff --git a/util/PlotsUtils.py b/util/PlotsUtils.py index 943305d..ce28109 100644 --- a/util/PlotsUtils.py +++ b/util/PlotsUtils.py @@ -5,7 +5,7 @@ import unicodedata from loguru import logger FONT_PATH = "../AIlib2/conf/platech.ttf" -zhFont = ImageFont.truetype(FONT_PATH, 20, encoding="utf-8") +zhFont = ImageFont.truetype(FONT_PATH, 20, encoding="utf-8") def get_label_array(color=None, label=None, font=None, fontSize=40, unify=False): if unify: @@ -24,7 +24,6 @@ def get_label_array(color=None, label=None, font=None, fontSize=40, unify=False) im_array = cv2.resize(im_array, (0, 0), fx=scale, fy=scale) return im_array - def get_label_arrays(labelNames, colors, fontSize=40, fontPath="platech.ttf"): font = ImageFont.truetype(fontPath, fontSize, encoding='utf-8') label_arraylist = [get_label_array(colors[i % 20], label_name, font, fontSize) for i, label_name in @@ -50,6 +49,48 @@ def get_label_array_dict(colors, fontSize=40, fontPath="platech.ttf"): zh_dict[code] = arr return zh_dict +def get_label_left(x0,y1,label_array,img): + imh, imw = img.shape[0:2] + lh, lw = label_array.shape[0:2] + # x1 框框左上x位置 + 描述的宽 + # y0 框框左上y位置 - 描述的高 + x1, y0 = x0 + lw, y1 - lh + # 如果y0小于0, 说明超过上边框 + if y0 < 0: + y0 = 0 + # y1等于文字高度 + y1 = y0 + lh + # 如果y1框框的高大于图片高度 + if y1 > imh: + # y1等于图片高度 + y1 = imh + # y0等于y1减去文字高度 + y0 = y1 - lh + # 如果x0小于0 + if x0 < 0: + x0 = 0 + x1 = x0 + lw + if x1 > imw: + x1 = imw + x0 = x1 - lw + return x0,y0,x1,y1 + +def get_label_right(x1,y0,label_array): + lh, lw = label_array.shape[0:2] + # x1 框框右上x位置 + 描述的宽 + # y0 框框右上y位置 - 描述的高 + x0, y1 = x1 - lw, y0 - lh + # 如果y0小于0, 说明超过上边框 + if y0 < 0 or y1 < 0: + y1 = 0 + # y1等于文字高度 + y0 = y1 + lh + # 如果x0小于0 + if x0 < 0 or x1 < 0: + x0 = 0 + x1 = x0 + lw + + return x0,y1,x1,y0 def xywh2xyxy(box): if not isinstance(box[0], (list, tuple, np.ndarray)): @@ -75,42 +116,24 @@ def xy2xyxy(box): box = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)] return box -def draw_painting_joint(box, img, label_array, score=0.5, color=None, config=None, isNew=False): +def draw_painting_joint(box, img, label_array, score=0.5, color=None, config=None, isNew=False, border=None): # 识别问题描述图片的高、宽 - lh, lw = label_array.shape[0:2] # 图片的长度和宽度 - imh, imw = img.shape[0:2] + if border is not None: + border = np.array(border,np.int32) + color,label_array=draw_name_border(box,color,label_array,border) + #img = draw_transparent_red_polygon(img,border,'',alpha=0.1) + + lh, lw = label_array.shape[0:2] + tl = config[0] + if isinstance(box[-1], np.ndarray): + return draw_name_points(img,box,color) + + label = ' %.2f' % score box = xywh2xyxy(box) # 框框左上的位置 x0, y1 = box[0][0], box[0][1] - # if score_location == 'leftTop': - # x0, y1 = box[0][0], box[0][1] - # # 框框左下的位置 - # elif score_location == 'leftBottom': - # x0, y1 = box[3][0], box[3][1] - # else: - # x0, y1 = box[0][0], box[0][1] - # x1 框框左上x位置 + 描述的宽 - # y0 框框左上y位置 - 描述的高 - x1, y0 = x0 + lw, y1 - lh - # 如果y0小于0, 说明超过上边框 - if y0 < 0: - y0 = 0 - # y1等于文字高度 - y1 = y0 + lh - # 如果y1框框的高大于图片高度 - if y1 > imh: - # y1等于图片高度 - y1 = imh - # y0等于y1减去文字高度 - y0 = y1 - lh - # 如果x0小于0 - if x0 < 0: - x0 = 0 - x1 = x0 + lw - if x1 > imw: - x1 = imw - x0 = x1 - lw + x0, y0, x1, y1 = get_label_left(x0, y1, label_array, img) # box_tl = max(int(round(imw / 1920 * 3)), 1) or round(0.002 * (imh + imw) / 2) + 1 ''' 1. img(array) 为ndarray类型(可以为cv.imread)直接读取的数据 @@ -120,14 +143,12 @@ def draw_painting_joint(box, img, label_array, score=0.5, color=None, config=Non 5. thickness(int):画线的粗细 6. shift:顶点坐标中小数的位数 ''' - tl = config[0] + img[y0:y1, x0:x1, :] = label_array box1 = np.asarray(box, np.int32) cv2.polylines(img, [box1], True, color, tl) - img[y0:y1, x0:x1, :] = label_array pts_cls = [(x0, y0), (x1, y1)] # 把英文字符score画到类别旁边 # tl = max(int(round(imw / 1920 * 3)), 1) or round(0.002 * (imh + imw) / 2) + 1 - label = ' %.2f' % score # tf = max(tl, 1) # fontScale = float(format(imw / 1920 * 1.1, '.2f')) or tl * 0.33 # fontScale = tl * 0.33 @@ -230,7 +251,6 @@ def draw_name_ocr(box, img, color, line_thickness=2, outfontsize=40): # (color=None, label=None, font=None, fontSize=40, unify=False) label_zh = get_label_array(color, box[0], font, outfontsize) return plot_one_box_auto(box[1], img, color, line_thickness, label_zh) - def filterBox(det0, det1, pix_dis): # det0为 (m1, 11) 矩阵 # det1为 (m2, 12) 矩阵 @@ -276,6 +296,7 @@ def plot_one_box_auto(box, img, color=None, line_thickness=2, label_array=None): # print("省略 :%s, lh:%s, lw:%s"%('+++' * 10, lh, lw)) # 图片的长度和宽度 imh, imw = img.shape[0:2] + points = None box = xy2xyxy(box) # 框框左上的位置 x0, y1 = box[0][0], box[0][1] @@ -316,7 +337,6 @@ def plot_one_box_auto(box, img, color=None, line_thickness=2, label_array=None): return img, box - def draw_name_crowd(dets, img, color, outfontsize=20): font = ImageFont.truetype(FONT_PATH, outfontsize, encoding='utf-8') if len(dets) == 2: @@ -367,4 +387,90 @@ def draw_name_crowd(dets, img, color, outfontsize=20): img[y0:y1, x0:x1, :] = label_arr - return img, dets \ No newline at end of file + return img, dets + +def draw_name_points(img,box,color): + font = ImageFont.truetype(FONT_PATH, 6, encoding='utf-8') + points = box[-1] + arrea = cv2.contourArea(points) + label = '火焰' + arealabel = '面积:%s' % f"{arrea:.1e}" + label_array_area = get_label_array(color, arealabel, font, 10) + label_array = get_label_array(color, label, font, 10) + lh_area, lw_area = label_array_area.shape[0:2] + box = box[:4] + # 框框左上的位置 + x0, y1 = box[0][0], max(box[0][1] - lh_area - 3, 0) + x1, y0 = box[1][0], box[1][1] + x0_label, y0_label, x1_label, y1_label = get_label_left(x0, y1, label_array, img) + x0_area, y0_area, x1_area, y1_area = get_label_right(x1, y0, label_array_area) + img[y0_label:y1_label, x0_label:x1_label, :] = label_array + img[y0_area:y1_area, x0_area:x1_area, :] = label_array_area + # cv2.drawContours(img, points, -1, color, tl) + cv2.polylines(img, [points], False, color, 2) + if lw_area < box[1][0] - box[0][0]: + box = [(x0, y1), (x1, y1), (x1, box[2][1]), (x0, box[2][1])] + else: + box = [(x0_label, y1), (x1, y1), (x1, box[2][1]), (x0_label, box[2][1])] + box = np.asarray(box, np.int32) + cv2.polylines(img, [box], True, color, 2) + return img, box + +def draw_name_border(box,color,label_array,border): + box = xywh2xyxy(box[:4]) + cx, cy = int((box[0][0] + box[2][0]) / 2), int((box[0][1] + box[2][1]) / 2) + flag = cv2.pointPolygonTest(border, (int(cx), int(cy)), + False) # 若为False,会找点是否在内,外,或轮廓上 + if flag == 1: + color = [0, 0, 255] + # 纯白色是(255, 255, 255),根据容差定义白色范围 + lower_white = np.array([255 - 30] * 3, dtype=np.uint8) + upper_white = np.array([255, 255, 255], dtype=np.uint8) + # 创建白色区域的掩码(白色区域为True,非白色为False) + white_mask = cv2.inRange(label_array, lower_white, upper_white) + # 创建与原图相同大小的目标颜色图像 + target_img = np.full_like(label_array, color, dtype=np.uint8) + # 先将非白色区域设为目标颜色,再将白色区域覆盖回原图颜色 + label_array = np.where(white_mask[..., None], label_array, target_img) + return color,label_array + +def draw_transparent_red_polygon(img, points, alpha=0.5): + """ + 在图像中指定的多边形区域绘制半透明红色 + + 参数: + image_path: 原始图像路径 + points: 多边形顶点坐标列表,格式为[(x1,y1), (x2,y2), ..., (xn,yn)] + output_path: 输出图像路径 + alpha: 透明度系数,0-1之间,值越小透明度越高 + """ + # 读取原始图像 + if img is None: + raise ValueError(f"无法读取图像") + + # 创建与原图大小相同的透明图层(RGBA格式) + overlay = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8) + + # 将点列表转换为适合cv2.fillPoly的格式 + #pts = np.array(points, np.int32) + pts = points.reshape((-1, 1, 2)) + + # 在透明图层上绘制红色多边形(BGR为0,0,255) + # 最后一个通道是Alpha值,控制透明度,黄色rgb + cv2.fillPoly(overlay, [pts], (255, 0, 0, int(alpha * 255))) + + # 将透明图层转换为BGR格式(用于与原图混合) + overlay_bgr = cv2.cvtColor(overlay, cv2.COLOR_RGBA2BGR) + + # 创建掩码,用于提取红色区域 + mask = overlay[:, :, 3] / 255.0 + mask = np.stack([mask] * 3, axis=-1) # 转换为3通道 + + # 混合原图和透明红色区域 + img = img * (1 - mask) + overlay_bgr * mask + img = img.astype(np.uint8) + + # # 保存结果 + # cv2.imwrite(output_path, result) + + return img \ No newline at end of file