hyf-backend/utils/YOLOTracker.py

467 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import os
import subprocess as sp
import time
from traceback import format_exc
from zipfile import ZipFile

import cv2
import numpy as np
from loguru import logger
from ultralytics import YOLO

from DrGraph.utils.Constant import Constant
from DrGraph.utils.Helper import *
from DrGraph.utils.pull_push import NetStream
class YOLOTracker:
    """Thin wrapper around an Ultralytics YOLO model that runs per-frame
    tracking (ByteTrack config) and keeps simple processing statistics."""

    def __init__(self, model_path):
        """Load the YOLO model at `model_path` and set default track options."""
        self.model = YOLO(model_path)
        # keyword options forwarded to model.track() on every frame
        self.tracking_config = {
            "tracker": "appIOs/configs/yolo11/bytetrack.yaml", # "/home/thsw/jcq/projects/yolov11/ultralytics-main/ultralytics/cfg/trackers/bytetrack.yaml",
            "conf": 0.25,
            "iou": 0.45,
            "persist": True,
            "verbose": False
        }
        self.frame_count = 0      # number of successfully processed frames
        self.processing_time = 0  # latency of the last frame, in milliseconds

    def process_frame(self, frame):
        """
        Run detection + tracking on a single frame.

        Returns:
            (annotated_frame, result): frame with detections drawn and the
            raw ultralytics result object, or (frame, None) on failure.
        """
        t0 = time.time()
        try:
            # only one image is passed, so only the first result matters
            result = self.model.track(source=frame, **self.tracking_config)[0]
            annotated = result.plot()
            self.processing_time = (time.time() - t0) * 1000  # seconds -> ms
            self.frame_count += 1
            # periodic stats logging every 100 frames
            if self.frame_count % 100 == 0:
                self._print_detection_info(result)
            return annotated, result
        except Exception:
            logger.error("YOLO处理异常: {}", format_exc())
            return frame, None

    def _print_detection_info(self, result):
        """Log detection count, number of distinct track ids and latency."""
        boxes = result.boxes
        if boxes is None or len(boxes) == 0:
            logger.info(f"{self.frame_count}: 未检测到目标, 处理时间: {self.processing_time:.2f}ms")
            return
        detection_count = len(boxes)
        # boxes without an id (not yet tracked) are simply skipped
        unique_ids = {int(box.id[0]) for box in boxes if box.id is not None}
        logger.info(f"{self.frame_count}: 检测到 {detection_count} 个目标, 追踪ID数: {len(unique_ids)}, 处理时间: {self.processing_time:.2f}ms")
class YOLOTrackerManager:
    """
    Manages one video/image input source (network pull stream, USB camera,
    local video, single file, image directory or labelled zip) and runs
    YOLO tracking over its frames.
    """
    def __init__(self, model_path, pull_url, push_url, request_id):
        """
        Args:
            model_path: path to the YOLO model weights.
            pull_url: default URL used by startPull().
            push_url: URL used when pushing processed frames.
            request_id: identifier used for logging / stream bookkeeping.
        """
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.tracker = YOLOTracker(model_path)
        self.stream = None        # NetStream when pulling from the network
        self.videoStream = None   # cv2.VideoCapture for camera / local video
        self.videoType = Constant.INPUT_NONE
        self.localFile = ''       # single local image (or zip label entry)
        self.localPath = ''       # image directory or labelled-zip path
        self.localFiles = []      # image entries for directory / zip input
        self._currentFrame = None # cached last frame
        self.totalFrames = 0
        self.frameChanged = False
        # BUG FIX: _frameIndex was only set by _stop()/start*(); calling
        # nextFrame()/getAnalysisFrame() before any start* method raised
        # AttributeError. Initialize it to the "stopped" value here.
        self._frameIndex = -1
def _stop(self):
if self.videoStream is not None:
self.videoStream.release()
self.videoStream = None
if self.stream is not None:
self.stream.clear_pull_p(self.stream.pull_p, self.request_id)
self.stream = None
self.localFile = ''
self.localPath = ''
self.localFiles = []
self._currentFrame = None
self.totalFrames = 0
self._frameIndex = -1
self.videoType = Constant.INPUT_NONE
self.frameChanged = True
def startLocalFile(self, fileName):
self._stop()
self.localFile = fileName
self._frameIndex = -1
def startLocalDir(self, dirName):
self._stop()
self.localPath = dirName
self.localFiles = [os.path.join(dirName, f) for f in os.listdir(dirName) if f.endswith(('.jpg', '.jpeg', '.png'))]
self.totalFrames = len(self.localFiles)
Helper.App.progressMax = self.totalFrames
self.localFiles.sort()
logger.info("本地目录打开: {}, 总帧数: {}", dirName, self.totalFrames)
self._frameIndex = 0
def startLabelledZip(self, labelledPath, categoryPath):
self._stop()
self.localPath = labelledPath
localFiles = ZipFile(labelledPath).namelist()
_, self.totalFrames = Helper.getYoloLabellingInfo(categoryPath, localFiles, '')
imagePath = categoryPath + 'images/'
self.localFiles = [file for file in localFiles if imagePath in file]
logger.info(f"标注压缩文件{labelledPath}{categoryPath}集共有{self.totalFrames}帧, 有效帧数: {len(self.localFiles)}")
self._frameIndex = 0
Helper.App.progressMax = self.totalFrames
def startUsbCamera(self, index = 0):
self._stop()
self.videoStream = cv2.VideoCapture(index)
self.videoType = Constant.INPUT_USB_CAMERA
Helper.Sleep(200)
if not self.videoStream.isOpened():
logger.error("无法打开USB摄像头: {}", index)
self.videoType = Constant.INPUT_NONE
return
self.totalFrames = 0x7FFFFFFF
def startLocalVideo(self, fileName):
self._stop()
self.videoStream = cv2.VideoCapture(fileName)
self.videoType = Constant.INPUT_LOCAL_VIDEO
Helper.Sleep(200)
if not self.videoStream.isOpened():
logger.error("无法打开本地视频流: {}", fileName)
self.videoType = Constant.INPUT_NONE
return
try:
total = int(self.videoStream.get(cv2.CAP_PROP_FRAME_COUNT))
except Exception:
total = 0
self.totalFrames = total if total is not None else 0
Helper.App.progressMax = self.totalFrames
logger.info("本地视频打开: {}, 总帧数: {}", fileName, self.totalFrames)
def startPull(self, url = ''):
self._stop()
if len(url) > 0:
self.pull_url = url
logger.info("拉流地址: {}", self.pull_url)
self.stream = NetStream(self.pull_url, self.push_url, self.request_id)
self.stream.prepare_pull()
def getCurrentFrame(self):
if self._currentFrame is None:
self._currentFrame = self.nextFrame()
if self._currentFrame is not None:
return self._currentFrame.copy()
return None
currentFrame = Property_Rw(getCurrentFrame, None)
    def setFrameIndex(self, index):
        """
        Seek to frame `index` (clamped to [0, totalFrames - 1]) and load it.

        Only meaningful for seekable inputs: a local video stream or a
        local file list (directory / labelled zip). Silently ignored for
        network pull streams and USB cameras.
        """
        # nothing seekable is open at all
        if self.videoStream is None and len(self.localFiles) == 0:
            return
        # a live capture that is not a local video (e.g. USB camera) cannot seek
        if self.videoStream is not None and self.videoType != Constant.INPUT_LOCAL_VIDEO:
            return
        # clamp the requested index into the valid range
        if index < 0:
            index = 0
        if index >= self.totalFrames:
            index = self.totalFrames - 1
        if self.videoStream:
            self.videoStream.set(cv2.CAP_PROP_POS_FRAMES, index)
        # nextFrame() advances _frameIndex, so park one step behind.
        # NOTE(review): in the directory branch nextFrame() reads
        # localFiles[_frameIndex] *before* incrementing, so this loads the
        # image at index-1 there -- confirm the intended off-by-one.
        self._frameIndex = index - 1
        self._currentFrame = self.nextFrame()
        self.frameChanged = True
    # NOTE(review): Property_rW here vs Property_Rw for currentFrame above --
    # presumably writer-only vs reader-only descriptors from Helper; confirm.
    frameIndex = Property_rW(setFrameIndex, 0)
def getLabels(self):
with ZipFile(self.localPath, 'r') as zip_ref:
content = zip_ref.read(self.localFile)
content = content.decode('utf-8')
return content
return ''
# 取得待分析的图像帧
def getAnalysisFrame(self, nextFlag):
frameChanged = self.frameChanged
self.frameChanged = False
if nextFlag: # 流式媒体
self._currentFrame = self.nextFrame()
self.frameChanged = True
frame = self.currentFrame
return frame.copy() if frame is not None else None, frameChanged
    def nextFrame(self):
        """
        Advance to and return the next frame (RGB), or None when exhausted
        or unreadable. Source priority: network stream > video capture >
        local file list (labelled zip or directory) > single local file.
        """
        frame = None
        if self.stream:
            # network pull stream: NetStream does its own buffering
            frame = self.stream.next_pull_frame()
        elif self.videoStream:
            ret, frame = self.videoStream.read()
            self._frameIndex += 1
            if not ret:
                # read failed (end of video / camera error): undo the advance
                self._frameIndex -= 1
                frame = None
        elif len(self.localFiles) > 0:
            if self.localPath.endswith('.zip'):
                # labelled zip: _frameIndex counts only '/images/' entries,
                # so scan the entry list for the matching one.
                # NOTE(review): index starts at -1 and is incremented *after*
                # the comparison, so _frameIndex == 0 matches the *second*
                # images entry -- confirm whether the first entry is meant
                # to be skipped.
                index = -1
                for img_file in self.localFiles:
                    if '/images/' in img_file:
                        if index == self._frameIndex:
                            # logger.warning(f'Loading image from zip file: {img_file}')
                            try:
                                with ZipFile(self.localPath, 'r') as zip_ref:
                                    image_data = zip_ref.read(img_file)
                                    nparr = np.frombuffer(image_data, np.uint8)
                                    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                                    self._frameIndex += 1
                                    # remember the matching label entry for getLabels()
                                    lable_file = img_file.replace('/images/', '/labels/').replace('.jpg', '.txt').replace('.png', '.txt')
                                    self.localFile = lable_file
                            except Exception as e:
                                # logger.error(f"读取压缩文件 {self.localPath} 中的 {img_file} 失败: {e}")
                                frame = None
                            break
                        index += 1
            else:
                # plain image directory: wrap around at the end of the list
                if self._frameIndex < 0:
                    self._frameIndex = 0
                if self._frameIndex >= len(self.localFiles):
                    self._frameIndex = 0
                if self._frameIndex < len(self.localFiles):
                    frame = cv2.imread(self.localFiles[self._frameIndex])
                    if frame is None:
                        logger.error(f"无法读取目标目录 {self.localPath}中下标为 {self._frameIndex} 的视频文件 {self.localFiles[self._frameIndex]}")
                        self._frameIndex = -1
                        return
                    self._frameIndex += 1
        elif self.localFile is not None and self.localFile != '':
            # single local image file
            frame = cv2.imread(self.localFile)
            if frame is None:
                logger.error("无法读取本地视频文件: {}", self.localFile)
                return
        if frame is not None:
            # OpenCV decodes BGR; downstream consumers expect RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.totalFrames > 0:
            # report progress for finite sources
            Helper.App.progress = self._frameIndex
        return frame
def test_yolo11_recognize(self, frame):
processed_frame = self.process_frame_with_yolo(frame, self.request_id)
return processed_frame
def process_frame_with_yolo(self, frame, requestId):
"""
使用YOLOv11处理帧
"""
try:
# 使用YOLO进行目标检测和追踪
processed_frame, detection_result = self.tracker.process_frame(frame)
# 在帧上添加处理信息
fps_info = f"FPS: {1000/max(self.tracker.processing_time, 1):.1f}"
cv2.putText(processed_frame, fps_info, (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
# 添加检测目标数量信息
if detection_result and detection_result.boxes is not None:
obj_count = len(detection_result.boxes)
count_info = f"Objects: {obj_count}"
cv2.putText(processed_frame, count_info, (10, 70),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
return processed_frame
except Exception as e:
logger.error("YOLO处理异常{}, requestId:{}", format_exc(), requestId)
# 如果处理失败,返回原帧
return frame
def get_gray_mask(self, frame):
"""
生成灰度像素的掩码图
灰度像素定义三颜色分量差小于20
"""
# 创建与原图大小相同的掩码图
maskMat = np.zeros(frame.shape[:2], dtype=np.uint8)
# 获取图像的三个颜色通道
b, g, r = cv2.split(frame)
r = r.astype(np.int16)
g = g.astype(np.int16)
b = b.astype(np.int16)
# 计算任意两个颜色分量之间的差值
diff_rg = np.abs(r - g)
is_shadow = (b > r) & (b - r < 40)
diff_rb = np.abs(r - b)
diff_gb = np.abs(g - b)
# 判断条件三颜色分量差都小于20
gray_pixels = (diff_rg < 20 ) & (diff_rb < 20| is_shadow) & (diff_gb < 20)
# 将满足条件的像素在掩码图中设为255白色
maskMat[gray_pixels] = 255
return maskMat
def debugLine(self, line, y_intersect):
x1, y1, x2, y2 = line
length = np.linalg.norm([x2 - x1, y2 - y1])
# 计算线与水平线的夹角(度数)
# 使用atan2计算弧度再转换为度数
angle_rad = math.atan2(y2 - y1, x2 - x1)
angle_deg = math.degrees(angle_rad)
# 调整角度范围到0-180度平面角
if angle_deg < 0:
angle_deg += 180
# angle_deg = min(angle_deg, 180 - angle_deg)
x_intersect = (x2 - x1) * (y_intersect - y1) / (y2 - y1) + x1
return angle_deg, length, x_intersect
    def test_highway_recognize(self, frame, debugFlag = False):
        """
        Experimental highway road-surface recognition.

        Masks gray (road-colored) pixels, keeps only large connected
        regions, overlays the detected surface in green and returns
        intermediate images for debugging.

        Args:
            frame: 3-channel image. NOTE(review): mutated in place (the top
                band is painted over) -- confirm callers pass a copy.
            debugFlag: reserved for the commented-out lane-line debugging.
        Returns:
            (overlay, color_mask, whiteLineMat) on success.
            NOTE(review): on failure a single image is returned instead of
            a 3-tuple -- callers that unpack will break; confirm.
        """
        processed_frame = frame.copy()
        try:
            IGNORE_HEIGHT = 100
            y_intersect = frame.shape[0] / 2
            # paint the top band so it never matches the gray mask
            frame[:IGNORE_HEIGHT, :] = (255, 0, 0)
            gray_mask = self.get_gray_mask(frame)
            kernel = np.ones((5, 5), np.uint8) # structuring element; erosions below drop small noise specks
            gray_mask = cv2.erode(gray_mask, kernel)
            gray_mask = cv2.erode(gray_mask, kernel)
            # drop connected regions with area below 10000
            contours, _ = cv2.findContours(gray_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # fresh mask keeping only regions with area >= 10000
            filtered_mask = np.zeros_like(gray_mask)
            for contour in contours:
                area = cv2.contourArea(contour)
                if area >= 10000: # fill qualifying contour regions
                    cv2.fillPoly(filtered_mask, [contour], 255)
            gray_mask = filtered_mask # replace gray_mask with the filtered one
            edges = cv2.Canny(frame, 100, 200) # edge detection
            road_edges = cv2.bitwise_and(edges, edges, mask=filtered_mask) # edges restricted to the road area
            # restrict the original image to the road mask for line finding
            whiteLineMat = cv2.bitwise_and(processed_frame, processed_frame, mask=filtered_mask)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_RGB2GRAY) # grayscale
            # Sobel edge detection (x gradient only)
            whiteLineMat = cv2.Sobel(whiteLineMat, cv2.CV_8U, 1, 0, ksize=3)
            tempMat = whiteLineMat.copy()
            # whiteLineMat = cv2.Canny(whiteLineMat, 100, 200)
            lines = cv2.HoughLinesP(tempMat, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB)
            # logger.info(f"{lines.shape[0]} lines: ")
            # if lines is not None:
            #     for line in lines:
            #         x1, y1, x2, y2 = line[0]
            #         cv2.line(whiteLineMat, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # color overlay marking the recognized road surface in green
            color_mask = cv2.cvtColor(gray_mask, cv2.COLOR_GRAY2RGB)
            color_mask[:] = (0, 255, 0) # set to green
            color_mask = cv2.bitwise_and(color_mask, color_mask, mask=filtered_mask)
            # blend the green road marking onto the frame
            overlay = cv2.addWeighted(processed_frame, 0.7, color_mask, 0.3, 0)
            # --- commented-out work-in-progress: solid-line detection via
            # --- Hough lines + 2-class angle clustering; kept for reference.
            # lines = cv2.HoughLinesP(road_edges, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)
            # logger.info(f"{lines.shape[0]} lines: ")
            # linesWithAngle = []
            # # if lines is not None:
            # for index, line in enumerate(lines):
            #     angle_deg, length, x_intersect = self.debugLine(line[0], y_intersect)
            #     linesWithAngle.append((line, angle_deg, x_intersect))
            #     if debugFlag:
            #         logger.info(f'line {index + 1}: {line}, 线长:{length:.2f}, 夹角:{angle_deg:.2f}°, 交点:({x_intersect:.2f}, {y_intersect:.2f})')
            # # cluster linesWithAngle into two classes by angle
            # line_data = np.array([[angle, x_intersect] for line, angle, x_intersect in linesWithAngle])
            # if len(line_data) > 0:
            #     labels = self._simple_kmeans(line_data, n_clusters=2, random_state=0)
            #     logger.info(f"聚类结果:{np.bincount(labels)}")
            #     if debugFlag:
            #         lines0 = [line for idx, line in enumerate(linesWithAngle) if labels[idx] == 0]
            #         lines1 = [line for idx, line in enumerate(linesWithAngle) if labels[idx] == 1]
            #         for index, line in enumerate(lines0):
            #             angle_deg, length, x_intersect = self.debugLine(line[0][0], y_intersect)
            #             logger.info(f'聚类0: {line[0]}, 线长:{length:.2f}, 夹角:{angle_deg:.2f}°, 交点:({x_intersect:.2f}, {y_intersect:.2f})')
            #         for index, line in enumerate(lines1):
            #             angle_deg, length, x_intersect = self.debugLine(line[0][0], y_intersect)
            #             logger.info(f'聚类1: {line[0]}, 线长:{length:.2f}, 夹角:{angle_deg:.2f}°, 交点:({x_intersect:.2f}, {y_intersect:.2f})')
            #     # keep the larger cluster and draw its lines
            #     dominant_cluster = np.argmax(np.bincount(labels))
            #     dominant_lines = [line for idx, line in enumerate(linesWithAngle) if labels[idx] == dominant_cluster]
            #     for line, angle, x_intersect in dominant_lines:
            #         cv2.line(overlay, (int(line[0][0]), int(line[0][1])), (int(line[0][2]), int(line[0][3])), (255, 0, 0), 2)
            return overlay, color_mask, whiteLineMat # cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB) # cv2.cvtColor(road_edges, cv2.COLOR_GRAY2RGB)
        except Exception as e:
            logger.error("路面识别异常:{}", format_exc())
            # on failure return the untouched copy of the input frame
            return processed_frame
    # def _simple_kmeans(self, data, n_clusters=2, max_iter=100, random_state=0):
    #     """
    #     Cluster `data` with a minimal K-means implementation.
    #     Args:
    #         data: array-like of shape (n_samples, n_features)
    #         n_clusters: number of clusters, default 2
    #         max_iter: maximum iterations, default 100
    #         random_state: seed for centroid initialization, default 0
    #     Returns:
    #         labels: array of shape (n_samples,) with cluster labels
    #     """
    #     np.random.seed(random_state)
    #     # pick the initial centroids at random
    #     centroids_idx = np.random.choice(len(data), size=n_clusters, replace=False)
    #     centroids = data[centroids_idx].copy()
    #     # iterate: assign points to the nearest centroid, then recompute
    #     for _ in range(max_iter):
    #         labels = np.zeros(len(data), dtype=int)
    #         for i, point in enumerate(data):
    #             # distance to each centroid; angle distances wrap at 180°,
    #             # so a difference d > 90 is folded to 180 - d
    #             distances = np.linalg.norm(centroids - point, axis=1)
    #             labels[i] = np.argmin(distances)
    #         # recompute centroids as cluster means; reseed empty clusters
    #         new_centroids = np.zeros_like(centroids)
    #         for c in range(n_clusters):
    #             members = data[labels == c]
    #             if len(members) > 0:
    #                 new_centroids[c] = members.mean(axis=0)
    #             else:
    #                 # empty cluster: pick a random point as its new centroid
    #                 new_centroids[c] = data[np.random.choice(len(data))]
    #         # stop when the centroids no longer move
    #         if np.allclose(centroids, new_centroids):
    #             break
    #         centroids = new_centroids
    #     return labels