449 lines
17 KiB
Python
449 lines
17 KiB
Python
import torch
|
||
import argparse
|
||
import sys,os
|
||
|
||
from torchvision import transforms
|
||
import cv2,glob
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import time
|
||
from pathlib import Path
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import tensorrt as trt
|
||
|
||
#import pycuda.driver as cuda
|
||
|
||
|
||
def get_largest_contours(contours):
    """Return the index of the contour with the largest area.

    Args:
        contours: sequence of contours as returned by ``cv2.findContours``.

    Returns:
        Index of the contour with the largest ``cv2.contourArea``; ties resolve
        to the first such contour. Returns -1 for an empty sequence, the same
        "no contour selected" sentinel the callers in this file start from.
    """
    # len() rather than truthiness so numpy arrays of contours also work.
    if len(contours) == 0:
        return -1
    areas = [cv2.contourArea(contour) for contour in contours]
    # max() keeps the first maximal element, matching areas.index(max(areas)).
    return max(range(len(areas)), key=areas.__getitem__)
|
||
|
||
def infer_usage():
    """Run the BiSeNet segmentation model on a folder of images and save overlays.

    Every file matched by the hard-coded glob is read with ``cv2.imread`` and
    segmented with ``segmodel.eval``. The largest contour of the predicted mask
    is drawn in yellow on the image, and the result is written to ``results/``.
    Files OpenCV cannot decode are skipped.

    NOTE(review): ``SegModel_BiSeNet`` and ``get_ms`` are neither defined nor
    imported in the visible source; presumably they are provided elsewhere.
    Confirm before calling.
    """
    nclass = 2
    # Older checkpoints, kept for reference:
    # weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    # weights = '../weights/BiSeNet/checkpoint.pth'
    weights = '../weights/BiSeNet/checkpoint_640X360_epo33.pth'
    segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights, modelsize=(640, 360))

    image_urls = glob.glob('../../../../data/无人机起飞测试图像/*')
    out_dir = 'results/'
    os.makedirs(out_dir, exist_ok=True)

    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        if image_array0 is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print('skip unreadable image:', image_url)
            continue
        H, W, C = image_array0.shape

        time_1 = time.time()
        pred, outstr = segmodel.eval(image_array0)

        binary0 = pred.copy()
        time0 = time.time()
        contours, _hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # drawContours treats a negative index as "all contours"; it is only
        # kept when no contour was found.
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            # Rebuild the mask so that it holds only the largest region.
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()

        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()

        out_url = os.path.join(out_dir, os.path.basename(image_url))
        if not cv2.imwrite(out_url, image_array0):
            print('failed to write image:', out_url)

        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f'%(im,os.path.basename(image_url),H,W,get_ms(time0,time_1),outstr,get_ms(time1,time0), get_ms(time3,time2),get_ms(time3,time_1)) )
|
||
|
||
def colorstr(*input):
    """Wrap a value in ANSI escape codes, e.g. ``colorstr('blue', 'hello world')``.

    All positional arguments except the last are style names looked up in the
    table below; the last argument is the text. With a single argument the
    text is styled as blue + bold. The result always ends with the reset code.
    See https://en.wikipedia.org/wiki/ANSI_escape_code.
    """
    palette = {
        # basic colors
        'black': '\033[30m',
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        # bright colors
        'bright_black': '\033[90m',
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        # misc
        'end': '\033[0m',
        'bold': '\033[1m',
        'underline': '\033[4m',
    }
    if len(input) > 1:
        *styles, text = input
    else:
        styles, text = ['blue', 'bold'], input[0]
    prefix = ''.join([palette[name] for name in styles])
    return f"{prefix}{text}{palette['end']}"
|
||
def file_size(path):
    """Return the size of a file, or the total size of a directory tree, in MB (1E6 bytes).

    Paths that are neither an existing file nor an existing directory yield 0.0.
    """
    target = Path(path)
    if target.is_file():
        total_bytes = target.stat().st_size
    elif target.is_dir():
        # rglob('*') walks the tree recursively, like glob('**/*').
        total_bytes = sum(entry.stat().st_size for entry in target.rglob('*') if entry.is_file())
    else:
        return 0.0
    return total_bytes / 1E6
|
||
|
||
|
||
def toONNX(seg_model,onnxFile,inputShape=(1,3,360,640),device=torch.device('cuda:0'),opset=11,dynamic=False):
    """Export ``seg_model`` to an ONNX file.

    The model is put in eval mode and called once as a smoke test with a random
    image tensor plus a dummy LongTensor. It is then exported with
    ``torch.onnx.export`` using the input name ``'images'`` and the output name
    ``'output'``.

    Args:
        seg_model: torch module called as ``seg_model(images, text_for_pred)``.
        onnxFile: destination path (str or pathlib.Path).
        inputShape: shape of the random dummy image, (batch, channels, height, width).
        device: device the dummy inputs are created on.
        opset: ONNX opset version passed to the exporter.
        dynamic: when True, mark batch/height/width of ``'images'`` and axes 0/1
            of ``'output'`` as dynamic.
    """
    print('####begin to export to onnx')

    im = torch.rand(inputShape).to(device)
    seg_model.eval()
    # NOTE(review): this extra all-zero (1, 90) LongTensor is passed to the model
    # and exported as a second input. It looks carried over from a
    # text-recognition (vitstr) export script; confirm seg_model needs it.
    text_for_pred = torch.zeros((1, 90), dtype=torch.long).to(device)

    # Smoke-test forward pass. no_grad avoids building an autograd graph that
    # is never used.
    with torch.no_grad():
        seg_model(im, text_for_pred)
    print('###test model infer example####')

    train = False
    torch.onnx.export(seg_model, (im, text_for_pred), str(onnxFile), opset_version=opset,
                      training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=not train,
                      input_names=['images'],
                      output_names=['output'],
                      dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},
                                    'output': {0: 'batch', 1: 'anchors'}
                                    } if dynamic else None)

    print('output onnx file:',onnxFile)
|
||
def ONNXtoTrt(onnxFile,trtFile,half=True,verbose=True,workspace=4):
    """Build a TensorRT engine from an ONNX file and serialize it to ``trtFile``.

    Args:
        onnxFile: path to the ONNX model.
        trtFile: output path for the serialized engine.
        half: request an FP16 engine. Silently downgraded to FP32 when
            ``builder.platform_has_fast_fp16`` is False.
        verbose: set the TensorRT logger to VERBOSE.
        workspace: builder workspace size in GiB.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or the engine build fails.
    """
    import tensorrt as trt

    time0 = time.time()
    prefix = colorstr('TensorRT:')
    f = trtFile

    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = int(workspace * (1 << 30))  # GiB -> bytes

    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnxFile)):
        errors = '; '.join(str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError(f'failed to load ONNX file: {onnxFile} ({errors})')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print(f'{prefix} Network Description:')
    for inp in inputs:
        print(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
    for out in outputs:
        print(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

    half &= builder.platform_has_fast_fp16
    print(f'{prefix} building FP{16 if half else 32} engine in {f}')
    if half:
        config.set_flag(trt.BuilderFlag.FP16)

    engine = builder.build_engine(network, config)
    if engine is None:
        # build_engine reports failure by returning None, not by raising.
        raise RuntimeError(f'failed to build TensorRT engine from {onnxFile}')
    with engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

    time1 = time.time()
    print('output trtfile from ONNX, time:%.4f s, half: ,'%(time1-time0),trtFile,half)
|
||
def ONNX_eval():
    """Sanity-check an ONNX segmentation model with onnxruntime, then segment a folder of images.

    First the model file is validated with ``onnx.checker`` and run once on a
    single probe image so that the output type/shape/dtype can be printed. Then
    every readable image matched by the hard-coded glob is preprocessed,
    segmented, and has its largest predicted contour drawn in yellow. The
    overlays are written to ``results/``.

    Raises:
        FileNotFoundError: if the probe image cannot be read.

    NOTE(review): ``segPreProcess_image`` and ``get_ms`` are neither defined nor
    imported in the visible source; presumably they are provided elsewhere.
    Confirm before calling.
    """
    import onnx
    import onnxruntime as ort

    # Alternative BiSeNet model, kept for reference:
    # model_path = '../weights/BiSeNet/checkpoint.onnx';modelSize=(512,512);mean=(0.335, 0.358, 0.332),std = (0.141, 0.138, 0.143)
    model_path = '../weights/STDC/model_maxmIOU75_1720_0.946_360640.onnx'
    modelSize = (640, 360)  # (width, height), the order cv2.resize expects
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # Validate the model file.
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)

    # Read one probe image and resize it to the model input size (HWC -> 1xCxHxW float32).
    # NOTE(review): unlike the loop below, the probe is fed raw BGR values with
    # no mean/std normalisation; its prediction is only used for the shape printout.
    probe_path = "../../river_demo/images/slope/菜地_20220713_青年河8_4335_1578.jpg"
    img = cv2.imread(probe_path)
    if img is None:
        raise FileNotFoundError('cannot read probe image: %s' % probe_path)
    H, W, C = img.shape
    img = cv2.resize(img, modelSize).transpose(2, 0, 1)
    img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)

    # Create the inference session and look up the (single) input name.
    sess = ort.InferenceSession(model_path, providers=ort.get_available_providers())
    print('len():', len(sess.get_inputs()))
    input_name1 = sess.get_inputs()[0].name

    output = sess.run(None, {input_name1: img})
    # argmax over axis 1 (presumably the class axis), first batch element.
    pred = np.argmax(output[0], axis=1)[0]
    pred = cv2.resize(pred.astype(np.uint8), (W, H))
    print('type:', type(output), output[0].shape, output[0].dtype)

    image_urls = glob.glob('../../../../data/无人机起飞测试图像/*')
    out_dir = 'results'
    os.makedirs(out_dir, exist_ok=True)

    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        if image_array0 is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print('skip unreadable image:', image_url)
            continue
        img = segPreProcess_image(image_array0, modelSize=modelSize, mean=mean, std=std, numpy=True)
        img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)
        H, W, C = image_array0.shape

        time_1 = time.time()
        output = sess.run(None, {input_name1: img})
        pred = np.argmax(output[0], axis=1)[0]
        pred = cv2.resize(pred.astype(np.uint8), (W, H))
        outstr = '###---###'

        binary0 = pred.copy()
        time0 = time.time()
        contours, _hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # drawContours treats a negative index as "all contours"; it is only
        # kept when no contour was found.
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            # Rebuild the mask so that it holds only the largest region.
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()

        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()

        out_url = os.path.join(out_dir, os.path.basename(image_url))
        if not cv2.imwrite(out_url, image_array0):
            print('failed to write image:', out_url)

        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f'%(im,os.path.basename(image_url),H,W,get_ms(time0,time_1),outstr,get_ms(time1,time0), get_ms(time3,time2),get_ms(time3,time_1)) )
        print('outimage:', out_url)
|
||
|
||
|
||
|
||
def EngineInfer_onePic_thread(pars_thread):
    """Segment one image with a TensorRT engine and save a contour overlay.

    Args:
        pars_thread: sequence ``[engine, image_array0, out_dir, image_url, im]``
            as built by ``EngineInfer``. ``image_array0`` is the image read by
            ``cv2.imread`` (it is drawn on in place), and ``im`` is the image
            index used in the log line. An optional 6th element, a dict, overrides
            keys of the preprocessing parameters passed to ``segtrtEval``.

    Returns:
        The string ``'success'``.

    NOTE(review): ``segtrtEval`` and ``get_ms`` are neither defined nor imported
    in the visible source; presumably they are provided elsewhere.
    """
    engine, image_array0, out_dir, image_url, im = pars_thread[:5]
    seg_par = {'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
               'numpy': False, 'RGB_convert_first': True}
    if len(pars_thread) > 5 and pars_thread[5]:
        seg_par.update(pars_thread[5])

    H, W, C = image_array0.shape
    time0 = time.time()

    # Run the model.
    pred, segInfoStr = segtrtEval(engine, image_array0, par=seg_par)
    # Swap foreground/background; assumes pred is a 0/1 mask - TODO confirm.
    pred = 1 - pred

    binary0 = pred.copy()
    time3 = time.time()
    contours, _hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Largest-contour selection is intentionally disabled here: -1 makes
    # drawContours draw every contour.
    max_id = -1
    time4 = time.time()

    cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
    time5 = time.time()

    out_url = os.path.join(out_dir, os.path.basename(image_url))
    if not cv2.imwrite(out_url, image_array0):
        print('failed to write image:', out_url)

    print('image:%d,%s ,%d*%d, %s,,findcontours:%.1f ms,draw:%.1f total:%.1f'%(im,os.path.basename(image_url),H,W,segInfoStr, get_ms(time4,time3),get_ms(time5,time4),get_ms(time5,time0) ))

    return 'success'
|
||
def trt_version():
    """Return the version string of the imported TensorRT Python package."""
    version_string = trt.__version__
    return version_string
|
||
def torch_device_from_trt(device):
    """Map a TensorRT ``TensorLocation`` to the matching ``torch.device``.

    Args:
        device: a ``trt.TensorLocation`` value.

    Returns:
        ``torch.device("cuda")`` for ``DEVICE`` and ``torch.device("cpu")`` for ``HOST``.

    Raises:
        TypeError: for any other location. The original code returned the
            exception object instead of raising it, so callers received a
            TypeError instance where a torch.device was expected.
    """
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    raise TypeError("%s is not supported by torch" % device)
|
||
|
||
def torch_dtype_from_trt(dtype):
    """Map a TensorRT ``DataType`` to the equivalent torch dtype.

    Args:
        dtype: a ``trt.DataType`` value, e.g. from ``engine.get_binding_dtype``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        TypeError: if the TensorRT dtype has no torch equivalent in this table.
    """
    # Some members (e.g. trt.bool, which the original guarded with a version
    # check) only exist in certain TensorRT releases, so each one is looked up
    # with getattr. The previous guard, trt_version() >= '7.0', compared
    # strings lexicographically and was False for '10.x', which skipped bool.
    candidates = (
        ('int8', torch.int8),
        ('bool', torch.bool),
        ('int32', torch.int32),
        ('float16', torch.float16),
        ('float32', torch.float32),
        ('uint8', torch.uint8),
        ('int64', torch.int64),
        ('bfloat16', torch.bfloat16),
    )
    for trt_name, torch_dtype in candidates:
        trt_dtype = getattr(trt, trt_name, None)
        if trt_dtype is not None and dtype == trt_dtype:
            return torch_dtype
    raise TypeError("%s is not supported by torch" % dtype)
|
||
def TrtForward(engine,inputs,contextFlag=False):
    """Run one synchronous TensorRT inference with torch tensors as buffers.

    Args:
        engine: deserialized TensorRT engine with the bindings ``'images'``
            (input) and ``'output'`` (output).
        inputs: sequence of torch tensors. Only ``inputs[0]`` is used; it is
            bound to every input name.
        contextFlag: an existing execution context to reuse, or a falsy value
            to create a new context for this call.

    Returns:
        ``(result, outstr)``. With the single hard-coded output name, ``result``
        is the allocated output tensor indexed with ``[0]`` (its leading axis
        removed). ``outstr`` is a timing breakdown in milliseconds.
    """
    t0 = time.time()
    context = contextFlag if contextFlag else engine.create_execution_context()

    input_names = ['images']
    output_names = ['output']
    batch_size = inputs[0].shape[0]
    bindings = [None] * (len(input_names) + len(output_names))
    t1 = time.time()

    # Allocate the output tensors and bind their memory.
    outputs = [None] * len(output_names)
    for i, output_name in enumerate(output_names):
        idx = engine.get_binding_index(output_name)
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
        # NOTE(review): batch_size is prepended to the engine's binding shape.
        # For an explicit-batch engine (as built by ONNXtoTrt) the binding shape
        # presumably already contains the batch axis, so this adds an extra
        # leading axis that the final ``outputs[0]`` strips again. Kept as-is
        # because callers depend on the returned shape.
        shape = (batch_size,) + tuple(engine.get_binding_shape(idx))
        device = torch_device_from_trt(engine.get_location(idx))
        output = torch.empty(size=shape, dtype=dtype, device=device)
        outputs[i] = output
        bindings[idx] = output.data_ptr()
    t2 = time.time()

    # Keep a live reference to the contiguous input. If inputs[0] is not
    # contiguous, .contiguous() returns a temporary copy; binding only its
    # data_ptr() would let that copy be freed before execute_v2 reads it.
    input_tensor = inputs[0].contiguous()
    for input_name in input_names:
        idx = engine.get_binding_index(input_name)
        # Every input name gets the same (single) image tensor.
        bindings[idx] = input_tensor.data_ptr()
    t3 = time.time()

    context.execute_v2(bindings)  # run inference
    t4 = time.time()

    if len(outputs) == 1:
        outputs = outputs[0]
    outstr='create Context:%.2f alloc memory:%.2f prepare input:%.2f conext infer:%.2f, total:%.2f'%((t1-t0 )*1000 , (t2-t1)*1000,(t3-t2)*1000,(t4-t3)*1000, (t4-t0)*1000 )
    return outputs[0], outstr
|
||
|
||
def EngineInfer(par):
    """Segment every image in a directory with a serialized TensorRT engine.

    Args:
        par: dict with at least ``'weights'`` (engine file path),
            ``'image_dir'``, ``'out_dir'`` and ``'max_threads'``. Other keys
            (``modelSize``, ``mean``, ``std``, ``RGB_convert_first``, ``device``)
            are accepted but not used here: ``EngineInfer_onePic_thread`` applies
            its own preprocessing defaults.

    Images that ``cv2.imread`` cannot decode are skipped. When ``max_threads``
    is 1 the images are processed sequentially; otherwise a
    ``ThreadPoolExecutor`` is used.

    Raises:
        RuntimeError: if the engine file cannot be deserialized.
    """
    weights = par['weights']
    image_dir = par['image_dir']
    max_threads = par['max_threads']
    out_dir = par['out_dir']

    image_urls = glob.glob('%s/*' % (image_dir))
    os.makedirs(out_dir, exist_ok=True)

    logger = trt.Logger(trt.Logger.ERROR)
    with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
        # Deserialize the local engine file into an ICudaEngine.
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        raise RuntimeError('failed to deserialize TensorRT engine: %s' % weights)
    print('#####load TRT file:', weights, 'success #####')

    pars_threads = []
    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        if image_array0 is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print('skip unreadable image:', image_url)
            continue
        pars_threads.append([engine, image_array0, out_dir, image_url, im])

    t1 = time.time()
    if max_threads == 1:
        for pars_thread in pars_threads:
            EngineInfer_onePic_thread(pars_thread)
    else:
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            # Consume the iterator so that worker exceptions propagate.
            list(executor.map(EngineInfer_onePic_thread, pars_threads))
    t2 = time.time()

    n_images = len(pars_threads)
    elapsed_ms = (t2 - t1) * 1000
    # max(..., 1) avoids ZeroDivisionError when the directory has no images.
    print('All %d images time:%.1f ms, each:%.1f ms , with %d threads' % (n_images, elapsed_ms, elapsed_ms / max(n_images, 1), max_threads))
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Command line: only the checkpoint path is configurable.
    cli = argparse.ArgumentParser()
    cli.add_argument('--weights', type=str, default='stdc_360X640.pth', help='model path(s)')
    opt = cli.parse_args()
    print(opt.weights)

    # Derive the export artefact paths from the checkpoint:
    # <name>.pth -> <name>.onnx -> <name>.engine
    # (example: Path('../../../yolov5TRT/weights/river/stdc_360X640.pth'))
    pthFile = Path(opt.weights)
    onnxFile = pthFile.with_suffix('.onnx')
    trtFile = onnxFile.with_suffix('.engine')

    nclass = 2
    device = torch.device('cuda:0')

    # Alternative BiSeNet configuration, kept for reference:
    #   weights = '../weights/BiSeNet/checkpoint.pth'; inputShape = (1, 3, 512, 512)
    #   segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights)
    #   seg_model = segmodel.model

    # STDC network.
    weights = pthFile
    segmodel = SegModel_STDC(nclass=nclass, weights=weights)
    inputShape = (1, 3, 360, 640)  # (batch, channels, height, width)
    seg_model = segmodel.model

    # Parameters for EngineInfer (currently disabled below).
    par = dict(
        modelSize=(inputShape[3], inputShape[2]),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        RGB_convert_first=True,
        weights=trtFile,
        device=device,
        max_threads=1,
        image_dir='../../river_demo/images/road',
        out_dir='results',
    )

    # infer_usage()
    toONNX(seg_model, onnxFile, inputShape=inputShape, device=device)
    ONNXtoTrt(onnxFile, trtFile)
    # EngineInfer(par)
    # ONNX_eval()
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|