# yolov5-th/detect.py
# (Removed GitHub file-view residue — "444 lines / 24 KiB / Python / Executable File"
# and the ambiguous-Unicode warning — which was UI text, not source code.)

import argparse
import time
from pathlib import Path
import random,string
import cv2
import torch
import torch.backends.cudnn as cudnn
import base64
import requests
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box,plot_one_box_PIL,draw_painting_joint,get_label_arrays,get_websource,smooth_outline_auto
from utils.get_offline_url import update_websource_offAndLive,platurlToJsonfile, get_websource_fromTxt
from utils.torch_utils import select_device, load_classifier, time_synchronized
import cv2
import queue
import os,json,sys
import numpy as np
from threading import Thread
import datetime,_thread
import subprocess as sp
import time
from PIL import Image, ImageDraw, ImageFont
from segutils.segmodel import SegModel,get_largest_contours
sys.path.extend(['/home/thsw2/WJ/src/yolov5/segutils'])
#from segutils.segWaterBuilding import SegModel,get_largest_contours,illBuildings
from segutils.core.models.bisenet import BiSeNet_MultiOutput
from collections import Counter
import matplotlib.pyplot as plt
# Platform endpoint queried for stream / offline-video metadata.
platform_query_url='http://47.96.182.154:9051/api/suanfa/getPlatformInfo'
# Ledger file recording the codes of offline videos that have finished processing.
offlineFile='mintors/offlines/doneCodes.txt'
def get_cls(array):
    """Return the most frequent element of *array* as an int.

    Ties resolve to the element encountered first, matching
    ``Counter.most_common`` ordering.
    """
    (winner, _count), = Counter(array).most_common(1)
    return int(winner)
## After Sep 3: a semantic-segmentation model must additionally be loaded.
## Oct 22: `source` changed to file-based input.
## Dec 31: support reading offline videos from the platform.
# Thread lock to guard against deadlocks between worker threads.
mutex = _thread.allocate_lock()
# Queue holding captured frames.
frame_queue = queue.Queue()
# Push address: the front end pulls the stream from this host IP; 1935/2019 is
# the ffmpeg port configured in nginx.
#camera_path='rtmp://58.200.131.2:1935/livetv/cctv1'
camera_path='/data/WJ/data/THexit/vedio/XiYuDaiHe4.MP4'
### label colors (BGR) ###
rainbows=[
    (0,0,255),(0,255,0),(255,0,0),(255,0,255),(255,255,0),(255,127,0),(255,0,127),
    (127,255,0),(0,255,127),(0,127,255),(127,0,255),(255,127,255),(255,255,127),
    (127,255,255),(0,255,255),(255,127,255),(127,255,255),
    (0,127,0),(0,0,127),(0,255,255)
]
def detect(save_img=False):
    """Run the main detection loop.

    Pulls frames from image folders / video files / RTMP-style streams, runs a
    water-segmentation model and a YOLOv5 detector, keeps only detections that
    overlap the water mask, draws results, pushes annotated frames to an RTMP
    server via an ffmpeg pipe, and periodically saves the highest-scoring
    "problem" snapshot to disk for reporting.

    Reads all runtime options from the module-level ``opt`` namespace; the
    ``save_img`` parameter is immediately overwritten from ``opt.nosave``.
    """
    rtmpUrl = "rtmp://127.0.0.1:1935/live/test"
    OutVideoW,OutVideoH,OutVideoFps=int(opt.OutVideoW),int(opt.OutVideoH),int(opt.OutVideoFps)
    # ffmpeg push command: raw BGR frames on stdin -> H.264/FLV to the RTMP server.
    command=['ffmpeg',
        '-y',
        #'-re',' ',
        '-f', 'rawvideo',
        '-vcodec','rawvideo',
        '-pix_fmt', 'bgr24',
        '-s', "{}x{}".format(OutVideoW,OutVideoH),  # output frame resolution
        #'-vcodec','libx264',
        #'-b','2500k',
        '-r', str(OutVideoFps),  # output frame rate
        '-i', '-',
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        #'-preset', 'ultrafast',
        '-f', 'flv',
        rtmpUrl]
    source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    #save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
    save_img = not opt.nosave
    # Numeric id, a .txt list of sources, or an rtsp/rtmp/http URL => streaming mode.
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))
    sourceTxt=source
    if source.endswith('.txt'):
        #source_list,port_list,streamName_list = get_websource(source)
        source_infos = get_websource_fromTxt(source)
    else:
        #source_list,port_list,streamName_list = [source],[1935],['demo']
        source_infos = [{'url':source,'port':1935,'name':'demo' }]
    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check img_size
    #print('###'*20,imgsz,stride)
    # Load the water-segmentation model.
    # NOTE(review): `weights` is rebound here after the detector is loaded — confirm intentional.
    seg_nclass = 2
    weights = 'weights/segmentation/BiSeNet/checkpoint.pth'
    segmodel = SegModel(nclass=seg_nclass,weights=weights,device=device)
    '''nclass = [2,2]
    Segmodel = BiSeNet_MultiOutput(nclass)
    weights='weights/segmentation/WaterBuilding.pth'
    segmodel = SegModel(model=Segmodel,nclass=nclass,weights=weights,device='cuda:0',multiOutput=True)'''
    if half:
        model.half()  # to FP16
    # Second-stage classifier (disabled)
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
    if webcam:
        # create file pointer (global run log; per-stream logs go to fp_log below)
        fp_out=open('mintors/%s.txt'%(time.strftime("Start-%Y-%m-%d-%H-%M-%S", time.localtime())) ,'w')
    # Set Dataloader
    vid_path, vid_writer = None, None
    stream_id = 0
    platurlToJsonfile(platform_query_url)
    while True:
        #for isource in range(len(source_list)):
        for isource in range(len(source_infos)):
            #source , port ,streamName = source_list[isource],port_list[isource],streamName_list[isource]
            source , port ,streamName = source_infos[isource]['url'],source_infos[isource]['port'],source_infos[isource]['name']
            print('########## detect.py line129 souce informations:',isource, source , port ,streamName,webcam )
            Push_Flag = False
            #for wang in ['test']:
            try:
                if webcam:
                    #view_img = check_imshow()
                    cudnn.benchmark = True  # set True to speed up constant image size inference
                    print('#########Using web cam#################')
                    dataset = LoadStreams(source, img_size=imgsz, stride=stride)
                    # Get names and colors; fp_log / fp_out are per-stream and global log files.
                    fp_log=open('mintors/detection/stream_%s_%d-%s.txt'%(streamName,stream_id,time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) ,'w')
                    fp_out.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())+ ' rtmp stream-%s-%d starts \n'%(streamName,stream_id) )
                    fp_out.flush()
                    # Candidate "problem" frames: [mean scores, dets, annotated imgs, frame ids, raw imgs].
                    problem_image = [[],[],[],[],[]];
                else:
                    dataset = LoadImages(source, img_size=imgsz, stride=stride)
                ####
                iimage_cnt = 0
                names = model.module.names if hasattr(model, 'module') else model.names
                EngLish_label = True
                # Optionally override class names from a JSON file of shape {"labelnames": [...]}.
                # NOTE(review): opt.labelnames defaults to None and os.path.exists(None) raises
                # TypeError — confirm callers always pass --labelnames.
                if os.path.exists(opt.labelnames):
                    with open(opt.labelnames,'r') as fp:
                        namesjson=json.load(fp)
                        names_fromfile=namesjson['labelnames']
                        if len(names_fromfile) == len(names):
                            names = names_fromfile
                            EngLish_label = False
                        else:
                            print('******Warning 文件:%s读取的类别数目与模型中的数目不一致,使用模型的类别********'%(opt.labelnames))
                colors = rainbows
                label_arraylist = get_label_arrays(names,colors,outfontsize=40)
                # Run inference
                if device.type != 'cpu':
                    model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
                t00 = time.time();
                if webcam:
                    # Launch the ffmpeg push pipe on this stream's port before reading frames.
                    while True:
                        if len(command) > 0:
                            rtmpUrl = "rtmp://127.0.0.1:%s/live/test"%(port)
                            command[-1] = rtmpUrl
                            # pipe configuration: frames are fed through ffmpeg's stdin
                            print(command)
                            ppipe = sp.Popen(command, stdin=sp.PIPE)
                            Push_Flag = True
                            break
                time00=time.time()
                for path, img, im0s, vid_cap in dataset:
                    t0= time_synchronized()
                    img = torch.from_numpy(img).to(device)
                    img = img.half() if half else img.float()  # uint8 to fp16/32
                    img /= 255.0  # 0 - 255 to 0.0 - 1.0
                    timeseg0 = time.time()
                    if segmodel:
                        # Segment the original-size frame (streams yield a batch list, files a single image).
                        if webcam:
                            seg_pred,segstr = segmodel.eval(im0s[0] )
                            #seg_pred = segmodel.eval(im0s[0],outsize=None,smooth_kernel=20)
                        else:
                            seg_pred,segstr = segmodel.eval(im0s )
                            #seg_pred = segmodel.eval(im0s[0],outsize=None,smooth_kernel=20)
                    timeseg1 = time.time()
                    if img.ndimension() == 3:
                        img = img.unsqueeze(0)
                    # Inference
                    t1 = time_synchronized()
                    pred = model(img, augment=opt.augment)[0]
                    #print('###','line197:',img.shape,opt.augment,opt.conf_thres, opt.iou_thres, opt.classes, opt.agnostic_nms)
                    # Apply NMS
                    pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
                    t2 = time_synchronized()
                    # Apply Classifier
                    if classify:
                        pred = apply_classifier(pred, modelc, img, im0s)
                    # Process detections
                    for i, det in enumerate(pred):  # detections per image
                        if webcam:  # batch_size >= 1
                            p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
                            im0_bak = im0.copy()
                        else:
                            p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
                        iimage_cnt += 1
                        p = Path(p)  # to Path
                        save_path = str(save_dir / p.name)  # img.jpg
                        txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
                        s += '%gx%g ' % img.shape[2:]  # print string
                        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                        if segmodel:
                            # Keep only the largest segmented region as the water mask and outline it on the frame.
                            contours, hierarchy = cv2.findContours(seg_pred,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
                            water = seg_pred.copy();
                            if len(contours)>0:
                                max_id = get_largest_contours(contours)
                                water[:,:]=0
                                cv2.fillPoly(water, [contours[max_id][:,0,:]], 1)
                                cv2.drawContours(im0,contours,max_id,(0,255,255),3)
                            else:
                                water[:,:] = 0
                            #im0,water = illBuildings(seg_pred,im0)
                        if len(det):
                            # Rescale boxes from img_size to im0 size
                            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                            # check whether each box lies inside the water area
                            if segmodel:
                                det_c = det.clone(); det_c=det_c.cpu().numpy()
                                #area_factors = np.array([np.sum(water[int(x[1]):int(x[3]), int(x[0]):int(x[2])] )/((x[2]-x[0])*(x[3]-x[1])) for x in det] )
                                area_factors = np.array([np.sum(water[int(x[1]):int(x[3]), int(x[0]):int(x[2])] )/((x[2]-x[0])*(x[3]-x[1])) for x in det_c] )
                                #det = det[area_factors>0.1]
                                # Drop boxes whose overlap with the water mask is <= 3% of box area.
                                det = det[area_factors>0.03]
                            # Per-customer request: temporarily suppress the aquatic-vegetation class.
                            '''if len(det):
                                clss = det[:,5]
                                det = det[clss!=2]'''
                        if len(det)>0:
                            # Print results
                            for c in det[:, -1].unique():
                                n = (det[:, -1] == c).sum()  # detections per class
                                s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
                            # Write results
                            for *xyxy, conf, cls in reversed(det):
                                if save_txt:  # Write to file
                                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                                    line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                                    with open(txt_path + '.txt', 'a') as f:
                                        f.write(('%g ' * len(line)).rstrip() % line + '\n')
                                if save_img or view_img:  # Add bbox to image
                                    label = f'{names[int(cls)]} {conf:.2f}'
                                    if EngLish_label:
                                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)%20], line_thickness=3)
                                    else:
                                        #im0=plot_one_box_PIL(xyxy, im0, label=label, color=colors[int(cls)%20], line_thickness=3)
                                        im0 = draw_painting_joint(xyxy,im0,label_arraylist[int(cls)],score=conf,color=rainbows[int(cls)%20],line_thickness=None)
                            # Queue this frame as a "problem image" candidate; every fpsample frames
                            # the one with the highest mean detection score gets reported.
                            if webcam:
                                problem_image[0].append( det[:,4].mean()); problem_image[1].append(det);
                                problem_image[2].append(im0);problem_image[3].append(iimage_cnt);problem_image[4].append(im0_bak)
                        # NOTE(review): bitwise `&` is used where `and` is meant; operands are bools so the
                        # result is the same, but there is no short-circuiting — confirm intended.
                        if webcam & (iimage_cnt % opt.fpsample == 0) & (dataset.mode != 'image') & (len(problem_image[0])>0):
                            best_index = problem_image[0].index(max(problem_image[0]))
                            best_frame = problem_image[3][ best_index]
                            img_send = problem_image[2][ best_index]
                            img_bak = problem_image[4][ best_index]
                            dets = problem_image[1][best_index]
                            # Majority class among the best frame's detections names the snapshot.
                            cls_max = get_cls(dets[:,5].cpu().detach().numpy())
                            time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                            uid=''.join(random.sample(string.ascii_letters + string.digits, 16))
                            AIUrl='problems/images_tmp/%s_frame-%d-%d_type-%s_%s_s-%s_AI.jpg'%(time_str,best_frame,iimage_cnt,names[cls_max],uid,streamName)
                            ORIUrl='problems/images_tmp/%s_frame-%d-%d_type-%s_%s_s-%s_OR.jpg'%(time_str,best_frame,iimage_cnt,names[cls_max],uid,streamName)
                            cv2.imwrite(AIUrl,img_send); cv2.imwrite(ORIUrl,img_bak)
                            outstr='%s save images to %s \n'%(time_str, AIUrl)
                            fp_log.write(outstr )
                            problem_image=[[],[],[],[],[]]
                        # Print time (inference + NMS)
                        t3 = time_synchronized()
                        # Save results (image with detections)
                        if save_img:
                            if dataset.mode == 'image':
                                cv2.imwrite(save_path, im0)
                            else:  # 'video' or 'stream'
                                if vid_path != save_path:  # new video
                                    vid_path = save_path
                                    if isinstance(vid_writer, cv2.VideoWriter):
                                        vid_writer.release()  # release previous video writer
                                    if vid_cap:  # video
                                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                                    else:  # stream
                                        fps, w, h = 30, im0.shape[1], im0.shape[0]
                                        save_path += '.mp4'
                                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                                # NOTE(review): vid_writer is created but never written to here; frames are
                                # only resized and pushed to the ffmpeg pipe — confirm local saving is
                                # intentionally absent.
                                im0 = cv2.resize(im0,(OutVideoW,OutVideoH))
                                if dataset.mode == 'stream':
                                    ppipe.stdin.write(im0.tostring())
                        t4 = time_synchronized()
                        #outstr='%s Done.read:%.1f ms, infer:%.1f ms, seginfer:%.1f ms,draw:%.1f ms, save:%.1f ms total:%.1f ms \n'%(s,(t1 - t0)*1000, (t2 - t1)*1000,(timeseg1-timeseg0)*1000, (t3 - t2)*1000,(t4-t3)*1000, (t4-t00)*1000)
                        timestr=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                        outstr='%s ,%s,iframe:%d,,read:%.1f ms,copy:%.1f, infer:%.1f ms, detinfer:%.1f ms,draw:%.1f ms, save:%.1f ms total:%.1f ms \n'%(s,timestr,iimage_cnt,(t0 - t00)*1000,(timeseg0-t0)*1000, (t1 - timeseg0)*1000,(t2-t1)*1000, (t3 - t2)*1000,(t4-t3)*1000, (t4-t00)*1000)
                        if webcam:
                            if len(det):
                                fp_log.write(outstr )
                        else:
                            print(outstr)
                            print(segstr)
                        sys.stdout.flush()
                        t00=t4
                if save_txt or save_img:
                    s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
                    print(f"Results saved to {save_dir}{s}")
                print(f'Done. ({time.time() - t0:.3f}s)')
                if not webcam:
                    break;
            except Exception as e:
                print('#######reading souce:%s ,error :%s:'%(source,e ))
            if Push_Flag and webcam:
                # Source ended (stream dropped or offline video finished): tear down the push pipe.
                ppipe.kill()
                fp_out.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +' rtmp stream-%s-%d ends \n'%(streamName,stream_id) );fp_out.flush()
                stream_id += 1
                time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                if 'off' in streamName:  # only offline videos get an "end" marker image and a done-record
                    EndUrl='problems/images_tmp/%s_frame-9999-9999_type-结束_9999999999999999_s-%s_AI.jpg'%(time_str,streamName)
                    img_end=np.zeros((100,100),dtype=np.uint8);cv2.imwrite(EndUrl,img_end)
                    EndUrl='problems/images_tmp/%s_frame-9999-9999_type-结束_9999999999999999_s-%s_OR.jpg'%(time_str,streamName)
                    cv2.imwrite(EndUrl,img_end)
                    time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                    with open( 'mintors/offlines/doneCodes.txt','a+' ) as fp:
                        fp.write('%s %s\n'%(time_str,streamName ))
                #source_infos=update_websource_offAndLive(platform_query_url,sourceTxt,offlineFile)
                fp_log.close()
            if webcam:
                fp_out.write( time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())+' GPU server sleep 10s \n' ) ;fp_out.flush()
                time.sleep(10)
        if not webcam:
            break;
        ### update the source list (online or offline)
        source_infos=update_websource_offAndLive(platform_query_url,sourceTxt,offlineFile)
        if len(source_infos)==0:
            print('######NO valid source sleep 10s#####')
            time.sleep(10)
    if webcam:
        fp_out.close()
if __name__ == '__main__':
    # Command-line interface; the parsed namespace is stored in the module-global
    # `opt`, which detect() reads directly.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='data/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--labelnames', type=str,default=None, help='labes nams')
    parser.add_argument('--fpsample', type=int, default=240, help='fpsample')
    parser.add_argument('--OutVideoW', type=int, default=1920, help='out video width size')
    parser.add_argument('--OutVideoH', type=int, default=1080, help='out video height size')
    parser.add_argument('--OutVideoFps', type=int, default=30, help='out video fps ')
    opt = parser.parse_args()
    print(opt)
    check_requirements(exclude=('pycocotools', 'thop'))
    # Inference only: disable autograd for the whole run.
    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                detect()
                strip_optimizer(opt.weights)
        else:
            detect()