import torch
import argparse
import sys, os
sys.path.extend(['segutils'])
from core.models.bisenet import BiSeNet
from model_stages import BiSeNet_STDC
from torchvision import transforms
import cv2, glob
import numpy as np
import matplotlib.pyplot as plt
import time
from pathlib import Path
from trtUtils import TRTModule, segTrtForward, segtrtEval, segPreProcess_image, get_ms
from concurrent.futures import ThreadPoolExecutor
import tensorrt as trt
from copy import deepcopy
import onnx
import onnxruntime as ort
#import pycuda.driver as cuda

class SegModel_BiSeNet(object):
    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        self.model = BiSeNet(nclass)
        checkpoint = torch.load(weights)
        if isinstance(modelsize, (list, tuple)):
            self.modelsize = modelsize
        else:
            self.modelsize = (modelsize, modelsize)
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        time0 = time.time()
        imageH, imageW, imageC = image.shape
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]  # argmax over the class channel -> HxW label map
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-process:%.1f, infer:%.1f, post-process:%.1f, post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        # resize, scale to [0,1], normalize per channel, swap R/B, then HWC -> NCHW
        image = cv2.resize(image, self.modelsize)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        return image

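# A minimal usage sketch for SegModel_BiSeNet (the checkpoint and image paths below are
# placeholders, not files shipped with this repo):
#   segmodel = SegModel_BiSeNet(nclass=2, weights='../weights/BiSeNet/checkpoint.pth', modelsize=512)
#   pred, outstr = segmodel.eval(cv2.imread('example.jpg'))  # pred: HxW uint8 label map
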
class SegModel_STDC(object):
    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0', modelSize=(360, 640)):
        self.model = BiSeNet_STDC(backbone='STDCNet813', n_classes=nclass,
                                  use_boundary_2=False, use_boundary_4=False,
                                  use_boundary_8=True, use_boundary_16=False,
                                  use_conv_last=False, modelSize=modelSize)
        self.device = device
        self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device)))
        self.model = self.model.to(self.device)
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.modelSize = modelSize

    def eval(self, image):
        time0 = time.time()
        imageH, imageW, _ = image.shape
        image = self.RB_convert(image)
        img = self.preprocess_image(image)
        if self.device != 'cpu':
            imgs = img.to(self.device)
        else:
            imgs = img
        time1 = time.time()
        self.model.eval()
        with torch.no_grad():
            output = self.model(imgs)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]  # argmax over the class channel -> HxW label map
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-process:%.1f, infer:%.1f, post-cpu-argmax:%.1f, post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        # resize to (width, height), scale to [0,1], normalize, HWC -> NCHW
        image = cv2.resize(image, (self.modelSize[1], self.modelSize[0]), interpolation=cv2.INTER_LINEAR)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        return image

    def RB_convert(self, image):
        # swap the R and B channels (BGR <-> RGB)
        image_c = image.copy()
        image_c[:, :, 0] = image[:, :, 2]
        image_c[:, :, 2] = image[:, :, 0]
        return image_c

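# A minimal usage sketch for SegModel_STDC; modelSize is (height, width) and the
# checkpoint path is a placeholder:
#   segmodel = SegModel_STDC(nclass=2, weights='stdc_360X640.pth', modelSize=(360, 640))
#   pred, outstr = segmodel.eval(cv2.imread('example.jpg'))  # pred: HxW uint8 label map
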
def get_largest_contours(contours):
    # return the index of the contour with the largest area
    areas = [cv2.contourArea(x) for x in contours]
    max_area = max(areas)
    max_id = areas.index(max_area)
    return max_id

def infer_usage(par):
    # par: dict with keys 'segmodel', 'image_dir' and 'out_dir' (see __main__ for a full example)
    segmodel = par['segmodel']
    image_urls = glob.glob('%s/*' % (par['image_dir']))
    out_dir = par['out_dir']
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls[0:1]):
        image_array0 = cv2.imread(image_url)
        H, W, C = image_array0.shape
        time_1 = time.time()
        pred, outstr = segmodel.eval(image_array0)
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        time1 = time.time()
        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        cv2.imwrite(out_url.replace('.', '_mask.'), (pred * 50).astype(np.uint8))
        time4 = time.time()
        print('image:%d,%s ,%d*%d, eval:%.1f ms, %s, findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))

def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']

def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0

def toONNX(seg_model, onnxFile, inputShape=(1, 3, 360, 640), device=torch.device('cuda:0'), dynamic=False):
    im = torch.rand(inputShape).to(device)
    seg_model.eval()
    out = seg_model(im)  # smoke-test the model before exporting
    print('### test model infer example over ###')
    train = False
    opset = 11
    print('#### begin to export to onnx')
    torch.onnx.export(seg_model, im, onnxFile, opset_version=opset,
                      training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=not train,
                      input_names=['images'],
                      output_names=['output'],
                      dynamic_axes={
                          'images': {0: 'batch_size', 2: 'in_height', 3: 'in_width'},
                          'output': {0: 'batch_size', 2: 'out_height', 3: 'out_width'}} if dynamic else None
                      )
    print('output onnx file:', onnxFile)

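# A quick parity check for the exported model (a sketch; assumes the 'images'/'output'
# names used above and a dummy input of the export shape):
#   sess = ort.InferenceSession(str(onnxFile), providers=ort.get_available_providers())
#   dummy = np.random.rand(1, 3, 360, 640).astype(np.float32)
#   onnx_out = sess.run(None, {'images': dummy})[0]
#   torch_out = seg_model(torch.from_numpy(dummy).to(device)).detach().cpu().numpy()
#   print('max abs diff:', np.abs(onnx_out - torch_out).max())
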
def ONNXtoTrt(onnxFile, trtFile):
    time0 = time.time()
    half = True; verbose = True; workspace = 4; prefix = colorstr('TensorRT:')
    f = trtFile  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnxFile)):
        raise RuntimeError(f'failed to load ONNX file: {onnxFile}')
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print(f'{prefix} Network Description:')
    for inp in inputs:
        print(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
    for out in outputs:
        print(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
    half &= builder.platform_has_fast_fp16
    print(f'{prefix} building FP{16 if half else 32} engine in {f}')
    if half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    time1 = time.time()
    print('output trtfile from ONNX, time:%.4f s,' % (time1 - time0), trtFile)

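# Note: build_engine and max_workspace_size are the older (pre-TensorRT-8.x) builder
# API used by this script. On TensorRT >= 8.4 the rough equivalent (untested here) is:
#   config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
#   plan = builder.build_serialized_network(network, config)
#   with open(f, 'wb') as t:
#       t.write(plan)
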
def ONNX_eval(par):
    model_path = par['weights']
    modelSize = par['modelSize']
    mean = par['mean']
    std = par['std']
    image_urls = glob.glob('%s/*' % (par['image_dir']))
    out_dir = par['out_dir']
    # validate the ONNX model
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)
    # create the inference session and query the input name
    sess = ort.InferenceSession(str(model_path), providers=ort.get_available_providers())
    print('len():', len(sess.get_inputs()))
    input_name1 = sess.get_inputs()[0].name
    half = False; device = 'cuda:0'
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls[0:1]):
        image_array0 = cv2.imread(image_url)
        img = segPreProcess_image(image_array0, modelSize=modelSize, mean=mean, std=std,
                                  numpy=True, RGB_convert_first=par['RGB_convert_first'])
        img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)
        H, W, C = image_array0.shape
        time_1 = time.time()
        print('### debug:', img.shape, os.path.basename(image_url))
        print('### debug: img[0,0,10:12,10:12]', img[0, 0, 10:12, 10:12])
        output = sess.run(None, {input_name1: img})
        pred = output[0]
        pred = np.argmax(pred, axis=1)[0]  # argmax over the class channel -> HxW label map
        pred = cv2.resize(pred.astype(np.uint8), (W, H))
        print('### debug: max(pred) =', np.max(pred))
        outstr = '###---###'
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        time1 = time.time()
        time2 = time.time()
        cv2.drawContours(image_array0, contours, -1, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        ret = cv2.imwrite(out_url.replace('.jpg', '_mask.jpg').replace('.png', '_mask.png'), (pred * 50).astype(np.uint8))
        time4 = time.time()
        print('image:%d,%s ,%d*%d, eval:%.1f ms, %s, findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))
        print('outimage:', out_url)

class SegModel_STDC_trt(object):
    def __init__(self, weights=None, modelsize=512, std=(0.229, 0.224, 0.225), mean=(0.485, 0.456, 0.406), device='cuda:0'):
        logger = trt.Logger(trt.Logger.INFO)
        with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())  # load the local .engine file; returns an ICudaEngine
        self.model = TRTModule(engine, ["images"], ["output"])
        self.mean = mean
        self.std = std
        self.device = device
        self.modelsize = modelsize

    def eval(self, image):
        time0 = time.time()
        H, W, C = image.shape
        img_input = self.segPreProcess_image(image)
        time1 = time.time()
        pred = self.model(img_input)
        time2 = time.time()
        pred = torch.argmax(pred, dim=1).cpu().numpy()[0]  # argmax over the class channel -> HxW label map
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (W, H))
        time4 = time.time()
        outstr = 'pre-process:%.1f, infer:%.1f, post-cpu-argmax:%.1f, post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time4, time0))
        return pred, outstr

    def segPreProcess_image(self, image):
        # resize to (width, height), swap R/B, scale to [0,1], normalize, HWC -> NCHW
        image = cv2.resize(image, self.modelsize)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        return image.to(self.device)

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

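# A minimal usage sketch for SegModel_STDC_trt; the .engine path is a placeholder and
# modelsize is (width, height) as expected by cv2.resize:
#   trt_model = SegModel_STDC_trt(weights='stdc_640X360.engine', modelsize=(640, 360))
#   pred, outstr = trt_model.eval(cv2.imread('example.jpg'))
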
def EngineInfer_onePic_thread(pars_thread):
    engine, image_array0, out_dir, image_url, im, par = pars_thread[0:6]
    out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
    H, W, C = image_array0.shape
    time0 = time.time()
    time1 = time.time()
    # run the model
    pred, segInfoStr = segtrtEval(engine, image_array0, par=par)
    cv2.imwrite(out_url.replace('.', '_mask.'), (pred * 50).astype(np.uint8))
    pred = 1 - pred
    time2 = time.time()
    outstr = '###---###'
    binary0 = pred.copy()
    time3 = time.time()
    contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    max_id = -1
    #if len(contours) > 0:
    #    max_id = get_largest_contours(contours)
    #    binary0[:, :] = 0
    #    cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
    time4 = time.time()
    cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
    time5 = time.time()
    ret = cv2.imwrite(out_url, image_array0)
    time6 = time.time()
    print('image:%d,%s ,%d*%d, %s, findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
        im, os.path.basename(image_url), H, W, segInfoStr,
        get_ms(time4, time3), get_ms(time5, time4), get_ms(time5, time0)))
    return 'success'

def EngineInfer(par):
    modelSize = par['modelSize']; mean = par['mean']; std = par['std']
    RGB_convert_first = par['RGB_convert_first']; device = par['device']
    weights = par['weights']; image_dir = par['image_dir']
    max_threads = par['max_threads']; par['numpy'] = False
    image_urls = glob.glob('%s/*' % (image_dir))
    out_dir = par['out_dir']
    os.makedirs(out_dir, exist_ok=True)
    logger = trt.Logger(trt.Logger.ERROR)
    with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())  # load the local .engine file; returns an ICudaEngine
    print('##### load TRT file:', weights, 'success #####')
    pars_threads = []
    for im, image_url in enumerate(image_urls[0:]):
        image_array0 = cv2.imread(image_url)
        pars_thread = [engine, image_array0, out_dir, image_url, im, par]
        pars_threads.append(pars_thread)
    t1 = time.time()
    if max_threads == 1:
        for i in range(len(pars_threads[0:])):
            EngineInfer_onePic_thread(pars_threads[i])
    else:
        with ThreadPoolExecutor(max_workers=max_threads) as t:
            for result in t.map(EngineInfer_onePic_thread, pars_threads):
                tt = result
    t2 = time.time()
    print('All %d images time:%.1f ms, each:%.1f ms, with %d threads' % (
        len(image_urls), (t2 - t1) * 1000, (t2 - t1) * 1000.0 / len(image_urls), max_threads))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='stdc_360X640.pth', help='model path(s)')
    parser.add_argument('--nclass', type=int, default=2, help='segmodel nclass')
    parser.add_argument('--mWidth', type=int, default=640, help='segmodel mWidth')
    parser.add_argument('--mHeight', type=int, default=360, help='segmodel mHeight')
    opt = parser.parse_args()
    print(opt.weights)
    pthFile = Path(opt.weights)
    onnxFile = str(pthFile.with_suffix('.onnx')).replace('360X640', '%dX%d' % (opt.mWidth, opt.mHeight))
    trtFile = onnxFile.replace('.onnx', '.engine')
    nclass = opt.nclass; device = torch.device('cuda:0')
    '''###BiSeNet
    weights = '../weights/BiSeNet/checkpoint.pth'; inputShape = (1, 3, 512, 512)
    segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights)
    seg_model = segmodel.model
    '''
    ##STDC net
    weights = pthFile
    inputShape = (1, 3, opt.mHeight, opt.mWidth)  # (bs, channels, height, width)
    segmodel = SegModel_STDC(nclass=nclass, weights=weights, modelSize=(inputShape[2], inputShape[3]))
    seg_model = segmodel.model
    par = {'modelSize': (inputShape[3], inputShape[2]), 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
           'RGB_convert_first': True, 'weights': trtFile, 'device': device, 'max_threads': 1, 'predResize': True,
           'image_dir': '../../AIdemo2/images/trafficAccident/', 'out_dir': 'results'}
    par_onnx = deepcopy(par)
    par_onnx['weights'] = onnxFile
    par_pth = deepcopy(par); par_pth['segmodel'] = segmodel
    #infer_usage(par_pth)
    #toONNX(seg_model, onnxFile, inputShape=inputShape, device=device, dynamic=True)
    ONNXtoTrt(onnxFile, trtFile)
    #EngineInfer(par)
    #ONNX_eval(par_onnx)
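
# Example invocation (the script filename here is hypothetical; use this file's actual name):
#   python pth2trt.py --weights stdc_360X640.pth --nclass 2 --mWidth 640 --mHeight 360
# The default flow loads the .pth checkpoint, expects a matching .onnx file exported next
# to it (uncomment the toONNX call to produce one), and builds a TensorRT .engine from it.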