AIlib2/yolov5.py

108 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from models.experimental import attempt_load
import tensorrt as trt
import sys
from segutils.trtUtils import yolov5Trtforward
from utilsK.queRiver import getDetectionsFromPreds,img_pad
from utils.datasets import letterbox
import numpy as np
import torch,time
import os
def score_filter_byClass(pdetections, score_para_2nd, default_score_th=0.7):
    """Keep only detections whose score exceeds their per-class threshold.

    Each detection is indexed positionally: ``det[4]`` is read as the score
    and ``det[5]`` as the class id (converted with ``int``).

    Args:
        pdetections: Iterable of indexable detections (at least 6 items each).
        score_para_2nd: Mapping from class id to score threshold. A class id
            may be stored either as an ``int`` key or as its decimal string
            (e.g. ``3`` or ``"3"``); the ``int`` key wins when both exist.
        default_score_th: Threshold applied to classes that have no entry in
            ``score_para_2nd``. Defaults to ``0.7``.

    Returns:
        A new list with the detections (same objects, original order) whose
        score is strictly greater than the applicable threshold.
    """
    kept = []
    for det in pdetections:
        score, cls_id = det[4], int(det[5])
        if cls_id in score_para_2nd:
            score_th = score_para_2nd[cls_id]
        elif str(cls_id) in score_para_2nd:
            # Thresholds loaded from JSON arrive with string keys.
            score_th = score_para_2nd[str(cls_id)]
        else:
            score_th = default_score_th
        # Strict comparison: a score equal to the threshold is dropped.
        if score > score_th:
            kept.append(det)
    return kept
class yolov5Model(object):
    """Wrapper that loads a YOLOv5 model and runs it on a single image.

    The backend is chosen from the weight file's extension:

    * ``.engine``      -> ``'trt'`` (engine deserialized through ``tensorrt``)
    * ``.pth`` / ``.pt`` -> ``'pth'`` (loaded with ``attempt_load``)
    * ``.jit``         -> ``'jit'`` (loaded with ``torch.jit.load``)

    Keys read from ``par``:

    * ``'device'`` and ``'half'`` -- required at construction time.
    * ``'conf_thres'`` and ``'iou_thres'`` -- required by :meth:`eval` for the
      ``'trt'`` and ``'pth'`` backends.
    * ``'score_byClass'`` -- optional per-class score thresholds applied after
      post-processing.
    * ``'ovlap_thres_crossCategory'`` -- optional, forwarded to
      ``getDetectionsFromPreds`` as ``ovlap_thres``.
    """

    def __init__(self, weights=None, par=None):
        """Load the model stored at ``weights``.

        Args:
            weights: Path to the weight file; its extension selects the
                backend. Must be a string in practice (``endswith`` is called
                on it).
            par: Configuration dict (see the class docstring for the keys).

        Raises:
            SystemExit: If the extension is not one of the registered types.
            AssertionError: If a ``.jit`` weight file does not exist.
        """
        # Use a fresh dict instead of a shared mutable default argument.
        par = {} if par is None else par
        self.par = par
        self.device = par['device']
        self.half = par['half']

        if weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights.endswith(('.pth', '.pt')):
            self.infer_type = 'pth'
        elif weights.endswith('.jit'):
            self.infer_type = 'jit'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            # NOTE(review): exits with status 0 although this is an error
            # path; kept as-is because callers may depend on it -- confirm.
            sys.exit(0)

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Takes the serialized local TRT file and returns an
                # ICudaEngine object.
                self.model = runtime.deserialize_cuda_engine(f.read())
        elif self.infer_type == 'pth':
            self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
            if self.half:
                self.model.half()
        elif self.infer_type == 'jit':
            # The message is now actually formatted with the offending path.
            assert os.path.exists(weights), "%s not exists" % weights
            self.model = torch.jit.load(weights, map_location=self.device)  # load FP32 model

        # Optional per-class score thresholds; None disables the extra filter.
        self.score_byClass = par.get('score_byClass')

        print('#########加载模型:', weights, ' 类型:', self.infer_type)

    def eval(self, image):
        """Run the model on ``image``.

        For the ``'jit'`` backend the image is passed to the model unchanged
        and the raw model output is returned. For ``'trt'`` and ``'pth'`` the
        image is pre-processed, inferred, and post-processed with
        ``getDetectionsFromPreds``; element ``[2]`` of its result is returned,
        optionally filtered by ``score_filter_byClass``.

        Args:
            image: Input image (presumably an HxWx3 BGR array, as the
                pre-processing reverses the last axis -- confirm with caller).

        Returns:
            Tuple ``(result, timeOut)`` where ``timeOut`` is a human-readable
            string with timings in milliseconds.
        """
        t0 = time.time()

        if self.infer_type == 'jit':
            pred = self.model(image)
            t3 = time.time()
            # NOTE(review): the total time is reported twice, the second time
            # under the "pre-process" label; string kept unchanged.
            timeOut = 'yolov5 :%.1f (pre-process:%.1f, ) ' % (self.get_ms(t3, t0), self.get_ms(t3, t0))
            return pred, timeOut

        img = self.preprocess_image(image)
        t1 = time.time()
        if self.infer_type == 'trt':
            pred = yolov5Trtforward(self.model, img)
        else:
            pred = self.model(img, augment=False)[0]
        t2 = time.time()

        ovlap_thres = self.par.get('ovlap_thres_crossCategory')
        # self.padInfos was set as a side effect of preprocess_image above.
        p_result, timeOut = getDetectionsFromPreds(
            pred,
            img,
            image,
            conf_thres=self.par['conf_thres'],
            iou_thres=self.par['iou_thres'],
            ovlap_thres=ovlap_thres,
            padInfos=self.padInfos,
        )
        if self.score_byClass:
            p_result[2] = score_filter_byClass(p_result[2], self.score_byClass)
        t3 = time.time()

        # The timing string from getDetectionsFromPreds is replaced by an
        # overall breakdown of the three stages.
        timeOut = 'yolov5 :%.1f (pre-process:%.1f, inference:%.1f, post-process:%.1f) ' % (
            self.get_ms(t3, t0),
            self.get_ms(t1, t0),
            self.get_ms(t2, t1),
            self.get_ms(t3, t2),
        )
        return p_result[2], timeOut

    def get_ms(self, t1, t0):
        """Return the interval ``t1 - t0`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Convert ``image`` into a normalized batch tensor on ``self.device``.

        Side effect: sets ``self.padInfos`` to the padding info returned by
        ``img_pad`` for the ``'trt'`` backend, or to ``None`` otherwise.

        Args:
            image: Single input image array.

        Returns:
            A tensor with a leading batch axis of size 1, channels moved to
            axis 1 and reversed, cast to fp16 or fp32 according to
            ``self.half``, and divided by 255.
        """
        if self.infer_type == 'trt':
            # The TRT path pads the image to a fixed 640x640x3 size.
            padded, padInfos = img_pad(image, size=(640, 640, 3))
            batch = [padded]
            self.padInfos = padInfos
        else:
            batch = [letterbox(image, 640, auto=True, stride=32)[0]]
            self.padInfos = None

        # Stack into a batch with a leading axis.
        img = np.stack(batch, 0)
        # Reverse the channel axis (BGR to RGB) and reorder NHWC -> NCHW.
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0
        return img