import torch
import argparse
import sys, os
sys.path.extend(['segutils'])
from core.models.bisenet import BiSeNet
from model_stages import BiSeNet_STDC
from torchvision import transforms
import cv2, glob
import numpy as np
import matplotlib.pyplot as plt
import time
from pathlib import Path
from trtUtils import TRTModule, segTrtForward, segtrtEval, segPreProcess_image, get_ms
from concurrent.futures import ThreadPoolExecutor
import tensorrt as trt
#import pycuda.driver as cuda
class SegModel_BiSeNet(object):
    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        self.model = BiSeNet(nclass)
        checkpoint = torch.load(weights)
        if isinstance(modelsize, (list, tuple)):
            self.modelsize = modelsize
        else:
            self.modelsize = (modelsize, modelsize)
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        time0 = time.time()
        imageH, imageW, imageC = image.shape
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]  # per-pixel class index
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-process:%.1f, infer:%.1f, post-process:%.1f, post-resize:%.1f, total:%.1f \n' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        image = cv2.resize(image, self.modelsize)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        # Note: the channel swap happens *after* normalization; kept as-is so the
        # per-channel statistics match what the checkpoint was trained with.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        return image
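# Usage sketch for SegModel_BiSeNet (hypothetical paths; assumes a checkpoint
# saved as {'model': state_dict}). Defined but never called by default.
def _demo_bisenet(weights='checkpoint.pth', image_url='example.jpg'):
    segmodel = SegModel_BiSeNet(nclass=2, weights=weights, modelsize=(640, 360))
    img = cv2.imread(image_url)        # BGR uint8, HxWx3
    pred, timing = segmodel.eval(img)  # pred: HxW uint8 class map at original size
    print(timing)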
class SegModel_STDC(object):
    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        self.model = BiSeNet_STDC(backbone='STDCNet813', n_classes=nclass,
                                  use_boundary_2=False, use_boundary_4=False,
                                  use_boundary_8=True, use_boundary_16=False,
                                  use_conv_last=False)
        self.device = device
        self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device)))
        self.model = self.model.to(self.device)
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def eval(self, image):
        time0 = time.time()
        imageH, imageW, _ = image.shape
        image = self.RB_convert(image)
        img = self.preprocess_image(image)
        if self.device != 'cpu':
            imgs = img.to(self.device)
        else:
            imgs = img
        time1 = time.time()
        self.model.eval()
        with torch.no_grad():
            output = self.model(imgs)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]  # per-pixel class index
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-process:%.1f, infer:%.1f, post-cpu-argmax:%.1f, post-resize:%.1f, total:%.1f \n' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        image = cv2.resize(image, (640, 360), interpolation=cv2.INTER_LINEAR)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        return image

    def RB_convert(self, image):
        # swap the R and B channels (BGR <-> RGB)
        image_c = image.copy()
        image_c[:, :, 0] = image[:, :, 2]
        image_c[:, :, 2] = image[:, :, 0]
        return image_c
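# SegModel_STDC is used the same way as SegModel_BiSeNet, but its checkpoint is a
# bare state_dict and the input size is fixed at 640x360 in preprocess_image, e.g.
# (hypothetical path):
#   segmodel = SegModel_STDC(nclass=2, weights='stdc_360X640.pth')
#   pred, timing = segmodel.eval(cv2.imread('example.jpg'))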
def get_largest_contours(contours):
    # return the index of the contour with the largest area
    areas = [cv2.contourArea(x) for x in contours]
    max_area = max(areas)
    max_id = areas.index(max_area)
    return max_id
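# Example (hypothetical binary mask): keep only the largest connected region.
#   contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
#   max_id = get_largest_contours(contours)
#   mask[:, :] = 0
#   cv2.fillPoly(mask, [contours[max_id][:, 0, :]], 1)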
def infer_usage():
    nclass = 2
    #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    #segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights)
    weights = '../weights/BiSeNet/checkpoint_640X360_epo33.pth'
    segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights, modelsize=(640, 360))
    image_urls = glob.glob('../../../../data/无人机起飞测试图像/*')
    out_dir = 'results/'
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        H, W, C = image_array0.shape
        time_1 = time.time()
        pred, outstr = segmodel.eval(image_array0)
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        if len(contours) > 0:
            # keep only the largest connected region
            max_id = get_largest_contours(contours)
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()
        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        time4 = time.time()
        print('image:%d,%s ,%d*%d, eval:%.1f ms, %s, findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
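# Example: a single argument defaults to blue+bold; the string always comes last.
#   print(colorstr('hello'))                  # '\033[34m\033[1mhello\033[0m'
#   print(colorstr('red', 'underline', 'x'))  # explicit styles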
def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0
def toONNX(seg_model, onnxFile, inputShape=(1, 3, 360, 640), device=torch.device('cuda:0')):
    print('####begin to export to onnx')
    im = torch.rand(inputShape).to(device)
    seg_model.eval()
    out = seg_model(im)  # smoke-test the model before export
    print('###test model infer example####')
    train = False
    dynamic = False
    opset = 11
    torch.onnx.export(seg_model, im, onnxFile, opset_version=opset,
                      training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=not train,
                      input_names=['images'],
                      output_names=['output'],
                      # dynamic axes are unused here since dynamic=False
                      dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},
                                    'output': {0: 'batch', 1: 'anchors'}
                                    } if dynamic else None)
    print('output onnx file:', onnxFile)
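# Optional sanity check of the exported file (a sketch, assuming the onnx package is installed):
#   import onnx
#   onnx.checker.check_model(onnx.load(str(onnxFile)))  # raises if the graph is malformed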
def ONNXtoTrt(onnxFile, trtFile):
    import tensorrt as trt
    time0 = time.time()
    half = True; verbose = True; workspace = 4; prefix = colorstr('TensorRT:')
    f = trtFile
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * (1 << 30)  # workspace GiB
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnxFile)):
        raise RuntimeError(f'failed to load ONNX file: {onnxFile}')
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print(f'{prefix} Network Description:')
    for inp in inputs:
        print(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
    for out in outputs:
        print(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
    half &= builder.platform_has_fast_fp16
    print(f'{prefix} building FP{16 if half else 32} engine in {f}')
    if half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    time1 = time.time()
    print('output trtfile from ONNX, time:%.4f s,' % (time1 - time0), trtFile)
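# Roughly equivalent CLI, if TensorRT's trtexec is available (flags vary by TRT version):
#   trtexec --onnx=model.onnx --saveEngine=model.engine --fp16 --workspace=4096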
def ONNX_eval():
    import onnx
    import onnxruntime as ort
    #model_path = '../weights/BiSeNet/checkpoint.onnx'; modelSize=(512,512); mean=(0.335, 0.358, 0.332); std=(0.141, 0.138, 0.143)
    model_path = '../weights/STDC/model_maxmIOU75_1720_0.946_360640.onnx'
    modelSize = (640, 360); mean = (0.485, 0.456, 0.406); std = (0.229, 0.224, 0.225)
    # validate the exported model
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)
    # read an image and reshape it to the network input layout (NCHW)
    img = cv2.imread("../../river_demo/images/slope/菜地_20220713_青年河8_4335_1578.jpg")
    H, W, C = img.shape
    img = cv2.resize(img, modelSize).transpose(2, 0, 1)
    img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)
    # create the inference session and query the input name
    sess = ort.InferenceSession(model_path, providers=ort.get_available_providers())
    print('len():', len(sess.get_inputs()))
    input_name1 = sess.get_inputs()[0].name
    output = sess.run(None, {input_name1: img})
    pred = np.argmax(output[0], axis=1)[0]  # per-pixel class index
    pred = cv2.resize(pred.astype(np.uint8), (W, H))
    print('type:', type(output), output[0].shape, output[0].dtype)
    half = False; device = 'cuda:0'
    image_urls = glob.glob('../../../../data/无人机起飞测试图像/*')
    out_dir = 'results'
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        img = segPreProcess_image(image_array0, modelSize=modelSize, mean=mean, std=std, numpy=True)
        img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)
        H, W, C = image_array0.shape
        time_1 = time.time()
        output = sess.run(None, {input_name1: img})
        pred = output[0]
        pred = np.argmax(pred, axis=1)[0]  # per-pixel class index
        pred = cv2.resize(pred.astype(np.uint8), (W, H))
        outstr = '###---###'
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()
        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        time4 = time.time()
        print('image:%d,%s ,%d*%d, eval:%.1f ms, %s, findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))
        print('outimage:', out_url)
class SegModel_STDC_trt(object):
    def __init__(self, weights=None, modelsize=512, std=(0.229, 0.224, 0.225), mean=(0.485, 0.456, 0.406), device='cuda:0'):
        logger = trt.Logger(trt.Logger.INFO)
        with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
            # deserialize the local TensorRT engine file into an ICudaEngine
            engine = runtime.deserialize_cuda_engine(f.read())
        self.model = TRTModule(engine, ["images"], ["output"])
        self.mean = mean
        self.std = std
        self.device = device
        self.modelsize = modelsize

    def eval(self, image):
        time0 = time.time()
        H, W, C = image.shape
        img_input = self.segPreProcess_image(image)
        time1 = time.time()
        pred = self.model(img_input)
        time2 = time.time()
        pred = torch.argmax(pred, dim=1).cpu().numpy()[0]  # per-pixel class index
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (W, H))
        time4 = time.time()
        outstr = 'pre-process:%.1f, infer:%.1f, post-cpu-argmax:%.1f, post-resize:%.1f, total:%.1f \n' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time4, time0))
        return pred, outstr

    def segPreProcess_image(self, image):
        image = cv2.resize(image, self.modelsize)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        return image.to(self.device)

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0
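# Usage sketch for the TensorRT path (hypothetical engine path; the engine must
# match the input size it was built with):
#   trt_model = SegModel_STDC_trt(weights='stdc_360X640.engine', modelsize=(640, 360))
#   pred, timing = trt_model.eval(cv2.imread('example.jpg'))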
def EngineInfer_onePic_thread(pars_thread):
    engine, image_array0, out_dir, image_url, im = pars_thread[0:5]
    H, W, C = image_array0.shape
    time0 = time.time()
    time1 = time.time()
    # run the model
    pred, segInfoStr = segtrtEval(engine, image_array0,
                                  par={'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406),
                                       'std': (0.229, 0.224, 0.225), 'numpy': False, 'RGB_convert_first': True})
    pred = 1 - pred
    time2 = time.time()
    outstr = '###---###'
    binary0 = pred.copy()
    time3 = time.time()
    contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    max_id = -1  # with max_id=-1, drawContours draws all contours
    #if len(contours) > 0:
    #    max_id = get_largest_contours(contours)
    #    binary0[:, :] = 0
    #    cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
    time4 = time.time()
    cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
    time5 = time.time()
    out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
    ret = cv2.imwrite(out_url, image_array0)
    time6 = time.time()
    print('image:%d,%s ,%d*%d, %s, findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
        im, os.path.basename(image_url), H, W, segInfoStr,
        get_ms(time4, time3), get_ms(time5, time4), get_ms(time5, time0)))
    return 'success'
def EngineInfer(par):
    modelSize = par['modelSize']; mean = par['mean']; std = par['std']
    RGB_convert_first = par['RGB_convert_first']; device = par['device']
    weights = par['weights']; image_dir = par['image_dir']
    max_threads = par['max_threads']
    image_urls = glob.glob('%s/*' % (image_dir))
    out_dir = par['out_dir']
    os.makedirs(out_dir, exist_ok=True)
    logger = trt.Logger(trt.Logger.ERROR)
    with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
        # deserialize the local TensorRT engine file into an ICudaEngine
        engine = runtime.deserialize_cuda_engine(f.read())
    print('#####load TRT file:', weights, 'success #####')
    pars_threads = []
    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        pars_threads.append([engine, image_array0, out_dir, image_url, im])
    t1 = time.time()
    if max_threads == 1:
        for pars_thread in pars_threads:
            EngineInfer_onePic_thread(pars_thread)
    else:
        with ThreadPoolExecutor(max_workers=max_threads) as t:
            for result in t.map(EngineInfer_onePic_thread, pars_threads):
                tt = result
    t2 = time.time()
    print('All %d images time:%.1f ms, each:%.1f ms, with %d threads' % (
        len(image_urls), (t2 - t1) * 1000, (t2 - t1) * 1000.0 / len(image_urls), max_threads))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='stdc_360X640.pth', help='model path(s)')
    opt = parser.parse_args()
    print(opt.weights)
    pthFile = Path(opt.weights)
    onnxFile = pthFile.with_suffix('.onnx')
    trtFile = onnxFile.with_suffix('.engine')
    nclass = 2; device = torch.device('cuda:0')
    '''###BiSeNet
    weights = '../weights/BiSeNet/checkpoint.pth'; inputShape = (1, 3, 512, 512)
    segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights)
    seg_model = segmodel.model
    '''
    ##STDC net
    weights = pthFile
    segmodel = SegModel_STDC(nclass=nclass, weights=weights)
    inputShape = (1, 3, 360, 640)  # (bs, channels, height, width)
    seg_model = segmodel.model
    par = {'modelSize': (inputShape[3], inputShape[2]), 'mean': (0.485, 0.456, 0.406),
           'std': (0.229, 0.224, 0.225), 'RGB_convert_first': True,
           'weights': trtFile, 'device': device, 'max_threads': 1,
           'image_dir': '../../river_demo/images/road', 'out_dir': 'results'}
    #infer_usage()
    toONNX(seg_model, onnxFile, inputShape=inputShape, device=device)
    ONNXtoTrt(onnxFile, trtFile)
    #EngineInfer(par)
    #ONNX_eval()