
ocrTrt.py 17KB

import torch
import argparse
import sys, os
from torchvision import transforms
import cv2, glob
import numpy as np
import matplotlib.pyplot as plt
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import tensorrt as trt
#import pycuda.driver as cuda
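
# Note: SegModel_BiSeNet, SegModel_STDC, segPreProcess_image and segtrtEval are
# referenced below but defined in project-local modules that this file does not
# import; bring them in from wherever they live in your tree.
# get_ms is also used throughout but never defined here. A minimal sketch,
# assuming it returns the elapsed time between two time.time() stamps in ms:
def get_ms(t_end, t_start):
    return (t_end - t_start) * 1000.0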
def get_largest_contours(contours):
    # Return the index of the contour with the largest area
    areas = [cv2.contourArea(x) for x in contours]
    max_area = max(areas)
    max_id = areas.index(max_area)
    return max_id
def infer_usage():
    image_url = '/home/thsw2/WJ/data/THexit/val/images/DJI_0645.JPG'
    nclass = 2
    #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    #weights = '../weights/BiSeNet/checkpoint.pth'
    #segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights)
    weights = '../weights/BiSeNet/checkpoint_640X360_epo33.pth'
    segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights, modelsize=(640, 360))
    image_urls = glob.glob('../../../../data/无人机起飞测试图像/*')
    out_dir = 'results/'
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls[0:]):
        #image_url = '/home/thsw2/WJ/data/THexit/val/images/54(199).JPG'
        image_array0 = cv2.imread(image_url)
        H, W, C = image_array0.shape
        time_1 = time.time()
        pred, outstr = segmodel.eval(image_array0)
        #plt.figure(1); plt.imshow(pred); plt.show()
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()
        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        time4 = time.time()
        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0
def toONNX(seg_model, onnxFile, inputShape=(1, 3, 360, 640), device=torch.device('cuda:0')):
    print('####begin to export to onnx')
    import onnx
    im = torch.rand(inputShape).to(device)
    seg_model.eval()
    text_for_pred = torch.LongTensor(1, 90).fill_(0).to(device)  # not used by the segmentation export
    out = seg_model(im)
    print('###test model infer example####')
    train = False
    dynamic = False
    opset = 11
    torch.onnx.export(seg_model, (im), onnxFile, opset_version=opset,
                      training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=not train,
                      input_names=['images'],
                      output_names=['output'],
                      dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                    'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                    } if dynamic else None)
    #torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)
    print('output onnx file:', onnxFile)
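
# Optional sanity check on the exported file (the same call that ONNX_eval
# below uses; illustrative, not part of the original export flow):
#   onnx.checker.check_model(onnx.load(str(onnxFile)))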
def ONNXtoTrt(onnxFile, trtFile, half=True):
    import tensorrt as trt
    #onnxFile = Path('../weights/BiSeNet/checkpoint.onnx')
    #onnxFile = Path('../weights/STDC/model_maxmIOU75_1720_0.946_360640.onnx')
    time0 = time.time()
    verbose = True
    workspace = 4  # GB
    prefix = colorstr('TensorRT:')
    f = trtFile  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnxFile)):
        raise RuntimeError('failed to load ONNX file: %s' % (onnxFile))
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print(f'{prefix} Network Description:')
    for inp in inputs:
        print(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
    for out in outputs:
        print(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
    half &= builder.platform_has_fast_fp16
    print(f'{prefix} building FP{16 if half else 32} engine in {f}')
    if half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    time1 = time.time()
    print('output trt file from ONNX, time:%.4f s, file:%s, half:%s' % (time1 - time0, trtFile, half))
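
# The calls above (max_workspace_size, build_engine) are deprecated and were
# removed in newer TensorRT releases. A minimal sketch of the same build with
# the TensorRT >= 8.4 API; ONNXtoTrt_v8 is a hypothetical name, not part of the
# original script:
def ONNXtoTrt_v8(onnxFile, trtFile, half=True, workspace=4):
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    # replaces config.max_workspace_size
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnxFile)):
        raise RuntimeError('failed to load ONNX file: %s' % onnxFile)
    if half and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    plan = builder.build_serialized_network(network, config)  # replaces build_engine + serialize
    with open(trtFile, 'wb') as t:
        t.write(plan)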
def ONNX_eval():
    import onnx
    import numpy as np
    import onnxruntime as ort
    import cv2
    #model_path = '../weights/BiSeNet/checkpoint.onnx'; modelSize = (512, 512); mean = (0.335, 0.358, 0.332); std = (0.141, 0.138, 0.143)
    model_path = '../weights/STDC/model_maxmIOU75_1720_0.946_360640.onnx'
    modelSize = (640, 360)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # Validate the model
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)
    # Read an image and reshape it to the model input dimensions
    img = cv2.imread("../../river_demo/images/slope/菜地_20220713_青年河8_4335_1578.jpg")
    H, W, C = img.shape
    img = cv2.resize(img, modelSize).transpose(2, 0, 1)
    img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)
    # Set up the inference session and query the input info
    sess = ort.InferenceSession(model_path, providers=ort.get_available_providers())
    print('len():', len(sess.get_inputs()))
    input_name1 = sess.get_inputs()[0].name
    #input_name2 = sess.get_inputs()[1].name
    #input_name3 = sess.get_inputs()[2].name
    #output = sess.run(None, {input_name1: img, input_name2: img, input_name3: img})
    output = sess.run(None, {input_name1: img})
    pred = np.argmax(output[0], axis=1)[0]  # per-pixel class id
    pred = cv2.resize(pred.astype(np.uint8), (W, H))
    #plt.imshow(pred); plt.show()
    print('type:', type(output), output[0].shape, output[0].dtype)
    #weights = Path('../weights/BiSeNet/checkpoint.engine')
    half = False
    device = 'cuda:0'
    image_url = '/home/thsw2/WJ/data/THexit/val/images/DJI_0645.JPG'
    #image_urls = glob.glob('../../river_demo/images/slope/*')
    image_urls = glob.glob('../../../../data/无人机起飞测试图像/*')
    #out_dir = '../../river_demo/images/results/'
    out_dir = 'results'
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls[0:]):
        image_array0 = cv2.imread(image_url)
        #img = segPreProcess_image(image_array0).to(device)
        img = segPreProcess_image(image_array0, modelSize=modelSize, mean=mean, std=std, numpy=True)
        #img = cv2.resize(img, (512, 512)).transpose(2, 0, 1)
        img = np.array(img)[np.newaxis, :, :, :].astype(np.float32)
        H, W, C = image_array0.shape
        time_1 = time.time()
        #pred, outstr = segmodel.eval(image_array0)
        output = sess.run(None, {input_name1: img})
        pred = output[0]
        #pred = model(img, augment=False, visualize=False)
        #pred = pred.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]  # per-pixel class id
        pred = cv2.resize(pred.astype(np.uint8), (W, H))
        outstr = '###---###'
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()
        time2 = time.time()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        time4 = time.time()
        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))
        print('outimage:', out_url)
def EngineInfer_onePic_thread(pars_thread):
    engine, image_array0, out_dir, image_url, im = pars_thread[0:5]
    H, W, C = image_array0.shape
    time0 = time.time()
    time1 = time.time()
    # Run the model
    pred, segInfoStr = segtrtEval(engine, image_array0,
                                  par={'modelSize': (640, 360), 'mean': (0.485, 0.456, 0.406),
                                       'std': (0.229, 0.224, 0.225), 'numpy': False,
                                       'RGB_convert_first': True})
    pred = 1 - pred
    time2 = time.time()
    outstr = '###---###'
    binary0 = pred.copy()
    time3 = time.time()
    contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    max_id = -1
    #if len(contours) > 0:
    #    max_id = get_largest_contours(contours)
    #    binary0[:, :] = 0
    #    cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
    time4 = time.time()
    cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
    time5 = time.time()
    out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
    ret = cv2.imwrite(out_url, image_array0)
    time6 = time.time()
    print('image:%d,%s ,%d*%d, %s,,findcontours:%.1f ms,draw:%.1f total:%.1f' % (
        im, os.path.basename(image_url), H, W, segInfoStr,
        get_ms(time4, time3), get_ms(time5, time4), get_ms(time5, time0)))
    return 'success'
def trt_version():
    return trt.__version__
def torch_device_from_trt(device):
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    elif device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    else:
        raise TypeError("%s is not supported by torch" % device)
def torch_dtype_from_trt(dtype):
    if dtype == trt.int8:
        return torch.int8
    elif trt_version() >= '7.0' and dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)
def TrtForward(engine, inputs, contextFlag=False):
    t0 = time.time()
    #with engine.create_execution_context() as context:
    if not contextFlag:
        context = engine.create_execution_context()
    else:
        context = contextFlag
    input_names = ['images']
    output_names = ['output']
    batch_size = inputs[0].shape[0]
    bindings = [None] * (len(input_names) + len(output_names))
    t1 = time.time()
    # Create the output tensors and allocate their memory
    outputs = [None] * len(output_names)
    for i, output_name in enumerate(output_names):
        idx = engine.get_binding_index(output_name)  # look up the binding index by name
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))  # matching dtype
        shape = (batch_size,) + tuple(engine.get_binding_shape(idx))  # matching shape
        device = torch_device_from_trt(engine.get_location(idx))
        output = torch.empty(size=shape, dtype=dtype, device=device)
        #print('&'*10, 'device:', device, 'idx:', idx, 'shape:', shape, 'dtype:', dtype, ' device:', output.get_device())
        outputs[i] = output
        #print('###line65:', output_name, i, idx, dtype, shape)
        bindings[idx] = output.data_ptr()  # bind the output data pointer
    t2 = time.time()
    for i, input_name in enumerate(input_names):
        idx = engine.get_binding_index(input_name)
        # Should be inputs[i] when there are several inputs; a single image is
        # used here, so every input binding points at the same tensor.
        bindings[idx] = inputs[0].contiguous().data_ptr()
        #print('#'*10, 'input_names:,', input_name, 'idx:', idx, inputs[0].dtype, ', inputs[0] device:', inputs[0].get_device())
    t3 = time.time()
    context.execute_v2(bindings)  # run inference
    t4 = time.time()
    if len(outputs) == 1:
        outputs = outputs[0]
    outstr = 'create Context:%.2f alloc memory:%.2f prepare input:%.2f context infer:%.2f, total:%.2f' % (
        (t1 - t0) * 1000, (t2 - t1) * 1000, (t3 - t2) * 1000, (t4 - t3) * 1000, (t4 - t0) * 1000)
    return outputs[0], outstr
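
# Illustrative usage of TrtForward (assumes an engine deserialized as in
# EngineInfer below and a preprocessed CUDA tensor matching the engine's input):
#   img = torch.rand(1, 3, 360, 640, device='cuda:0')
#   pred, timing = TrtForward(engine, [img])
#   print(timing)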
def EngineInfer(par):
    modelSize = par['modelSize']
    mean = par['mean']
    std = par['std']
    RGB_convert_first = par['RGB_convert_first']
    device = par['device']
    weights = par['weights']
    image_dir = par['image_dir']
    max_threads = par['max_threads']
    image_urls = glob.glob('%s/*' % (image_dir))
    out_dir = par['out_dir']
    os.makedirs(out_dir, exist_ok=True)
    #trt_model = SegModel_STDC_trt(weights=weights, modelsize=modelSize, std=std, mean=mean, device=device)
    logger = trt.Logger(trt.Logger.ERROR)
    with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())  # deserialize the local TRT file into an ICudaEngine
    print('#####load TRT file:', weights, 'success #####')
    pars_thread = []
    pars_threads = []
    for im, image_url in enumerate(image_urls[0:]):
        image_array0 = cv2.imread(image_url)
        pars_thread = [engine, image_array0, out_dir, image_url, im]
        pars_threads.append(pars_thread)
        #EngineInfer_onePic_thread(pars_thread)
    t1 = time.time()
    if max_threads == 1:
        for i in range(len(pars_threads[0:])):
            EngineInfer_onePic_thread(pars_threads[i])
    else:
        with ThreadPoolExecutor(max_workers=max_threads) as t:
            for result in t.map(EngineInfer_onePic_thread, pars_threads):
                tt = result
    t2 = time.time()
    print('All %d images time:%.1f ms, each:%.1f ms , with %d threads' % (
        len(image_urls), (t2 - t1) * 1000, (t2 - t1) * 1000.0 / len(image_urls), max_threads))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='stdc_360X640.pth', help='model path(s)')
    opt = parser.parse_args()
    print(opt.weights)
    #pthFile = Path('../../../yolov5TRT/weights/river/stdc_360X640.pth')
    pthFile = Path(opt.weights)
    onnxFile = pthFile.with_suffix('.onnx')
    trtFile = onnxFile.with_suffix('.engine')
    nclass = 2
    device = torch.device('cuda:0')
    '''###BiSeNet
    weights = '../weights/BiSeNet/checkpoint.pth'; inputShape = (1, 3, 512, 512)
    segmodel = SegModel_BiSeNet(nclass=nclass, weights=weights)
    seg_model = segmodel.model
    '''
    ##STDC net
    weights = pthFile
    segmodel = SegModel_STDC(nclass=nclass, weights=weights)
    inputShape = (1, 3, 360, 640)  # (bs, channels, height, width)
    seg_model = segmodel.model
    par = {'modelSize': (inputShape[3], inputShape[2]), 'mean': (0.485, 0.456, 0.406),
           'std': (0.229, 0.224, 0.225), 'RGB_convert_first': True,
           'weights': trtFile, 'device': device, 'max_threads': 1,
           'image_dir': '../../river_demo/images/road', 'out_dir': 'results'}
    #infer_usage()
    toONNX(seg_model, onnxFile, inputShape=inputShape, device=device)
    ONNXtoTrt(onnxFile, trtFile)
    #EngineInfer(par)
    #ONNX_eval()
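
# Example invocation (the weights path is illustrative):
#   python ocrTrt.py --weights ../weights/STDC/stdc_360X640.pth
# This exports stdc_360X640.onnx and then builds stdc_360X640.engine next to it.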