#
# Copyright 1993-2020 NVIDIA Corporation. All rights reserved.
#
# NOTICE TO LICENSEE:
#
# This source code and/or documentation ("Licensed Deliverables") are
# subject to NVIDIA intellectual property rights under U.S. and
# international Copyright laws.
#
# These Licensed Deliverables contained herein is PROPRIETARY and
# CONFIDENTIAL to NVIDIA and is being provided under the terms and
# conditions of a form of NVIDIA software license agreement by and
# between NVIDIA and Licensee ("License Agreement") or electronically
# accepted by Licensee. Notwithstanding any terms or conditions to
# the contrary in the License Agreement, reproduction or disclosure
# of the Licensed Deliverables to any third party without the express
# written consent of NVIDIA is prohibited.
#
# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
# OF THESE LICENSED DELIVERABLES.
#
# U.S. Government End Users. These Licensed Deliverables are a
# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
# 1995), consisting of "commercial computer software" and "commercial
# computer software documentation" as such terms are used in 48
# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
# only as a commercial end item. Consistent with 48 C.F.R.12.212 and
# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
# U.S. Government End Users acquire the Licensed Deliverables with
# only those rights set forth herein.
#
# Any use of the Licensed Deliverables in individual and commercial
# software must include, in the user documentation and internal
# comments to the code, the above Disclaimer and U.S. Government End
# Users Notice.
#
import argparse
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import torch
import tensorrt as trt
import time
import onnx
import onnxruntime
import os, sys, cv2
from model.u2net import U2NET  # required by the __main__ example below
#cuda.init()
model_names = ['u2net.onnx', 'u2net_dynamic_batch.onnx', 'u2net_dynamic_hw.onnx', 'u2net_dynamic_batch-hw.onnx']
# Dynamic-axes dictionaries for torch.onnx.export: each maps a tensor name to the
# dimensions that stay dynamic in the exported graph.
dynamic_batch = {'input':   {0: 'batch'},
                 'output0': {0: 'batch'},
                 'output1': {0: 'batch'},
                 'output2': {0: 'batch'},
                 'output3': {0: 'batch'},
                 'output4': {0: 'batch'},
                 'output5': {0: 'batch'},
                 'output6': {0: 'batch'}}
dynamic_hw = {'input':   {2: 'H', 3: 'W'},
              'output0': {2: 'H', 3: 'W'},
              'output1': {2: 'H', 3: 'W'},
              'output2': {2: 'H', 3: 'W'},
              'output3': {2: 'H', 3: 'W'},
              'output4': {2: 'H', 3: 'W'},
              'output5': {2: 'H', 3: 'W'},
              'output6': {2: 'H', 3: 'W'}}
dynamic_batch_hw = {'input':   {0: 'batch', 2: 'H', 3: 'W'},
                    'output0': {0: 'batch', 2: 'H', 3: 'W'},
                    'output1': {0: 'batch', 2: 'H', 3: 'W'},
                    'output2': {0: 'batch', 2: 'H', 3: 'W'},
                    'output3': {0: 'batch', 2: 'H', 3: 'W'},
                    'output4': {0: 'batch', 2: 'H', 3: 'W'},
                    'output5': {0: 'batch', 2: 'H', 3: 'W'},
                    'output6': {0: 'batch', 2: 'H', 3: 'W'}}
dynamic_ = [None, dynamic_batch, dynamic_hw, dynamic_batch_hw]
TRT_LOGGER = trt.Logger()
def pth2onnx(pth_model, onnx_name, input_shape=(1,3,512,512), input_names=['input'], output_names=['output'], dynamix_axis=None):
    # pth_model: the PyTorch model with its weights already loaded
    # onnx_name: path of the output ONNX model
    # input_shape: model input size (a representative/suggested size)
    # input_names: names of the model inputs, as a list; multiple inputs are allowed
    # output_names: names of the model outputs, as a list; multiple outputs are allowed
    # dynamix_axis: dict; None means a fully static model. Every model input/output may
    #   declare dynamic dimensions, e.g.
    #   dynamic_batch_hw = {'input':{0:'batch',2:'H',3:'W'}, 'output':{0:'batch',2:'H',3:'W'}}
    #   means that B, H, W of 'input' and B, H, W of 'output' are dynamic.
    print('[I] begin converting pth to onnx ...... ', dynamix_axis)
    input_tensor = torch.ones(input_shape)
    if next(pth_model.parameters()).is_cuda:
        input_tensor = input_tensor.to('cuda:0')
    with torch.no_grad():
        torch.onnx.export(pth_model,
                          input_tensor,
                          onnx_name,
                          opset_version=11,
                          input_names=input_names,
                          do_constant_folding=True,
                          output_names=output_names,
                          dynamic_axes=dynamix_axis)
    onnx_model = onnx.load(onnx_name)
    try:
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        print('[Error] model incorrect:', e)
    else:
        print('[I] convert to onnx finished, saved to', onnx_name)
    print('')
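# Example (sketch) of calling pth2onnx to export U2NET with dynamic H/W axes. The weights
# path 'weights/u2net.pth' and the importable model.u2net.U2NET class are assumptions,
# not guaranteed by this script.
#
#   net = U2NET(3, 1)
#   net.load_state_dict(torch.load('weights/u2net.pth', map_location='cpu'))
#   net.eval()
#   pth2onnx(net, 'u2net_dynamic_hw.onnx',
#            input_shape=(1, 3, 512, 512),
#            input_names=['input'],
#            output_names=['output%d' % i for i in range(7)],
#            dynamix_axis=dynamic_hw)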
def onnx_inference(onnx_input, model_name, outputName=['output0','output1','output2','output3','output4','output5','output6']):
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    print('[I] onnxruntime inference with model:', model_name)
    #outputName = ['pred_logits', 'pred_points']
    onnx_session = onnxruntime.InferenceSession(model_name, providers=providers)
    try:
        onnx_output = onnx_session.run(outputName, onnx_input)
    except Exception as e:
        onnx_output = None
        print(e)
    return onnx_output
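# Example (sketch): running one of the exported ONNX models through onnxruntime on a
# random input. The file name is one of the entries in model_names and is assumed to exist.
#
#   dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)
#   outs = onnx_inference({'input': dummy}, 'u2net_dynamic_hw.onnx')
#   if outs is not None:
#       print([o.shape for o in outs])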
def onnx2engine(onnx_file_path, engine_file_path, input_shape=[1,3,512,512], half=True, max_batch_size=1, input_profile_shapes=[None,None,None]):
    # onnx_file_path: path of the input ONNX model
    # engine_file_path: path of the output TensorRT engine
    # input_shape: default model input shape, e.g. [1,3,512,512]; for dynamic shapes use e.g. [1,3,-1,-1]
    # half: whether to build in FP16, default True
    # max_batch_size: maximum batch size, default 1
    # input_profile_shapes: for dynamic input, the three profile shapes [min, optimal, max];
    #   in that case input_shape must contain -1, e.g. (1,3,512,512), (1,3,1024,1024), (1,3,2048,2048)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    runtime = trt.Runtime(TRT_LOGGER)
    # Maximum workspace memory TensorRT may use during optimization (typically 1 GB);
    # lower it if GPU memory overflows.
    config.max_workspace_size = 1 << 30  # 1 GB
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    builder.max_batch_size = max_batch_size  # at inference time batch_size must be <= max_batch_size
    # Parse the model file.
    if not os.path.exists(onnx_file_path):
        print(f'onnx file {onnx_file_path} not found, please run torch_2_onnx.py first to generate it')
        exit(0)
    print(f'Loading ONNX file from path {onnx_file_path}...')
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    # Static input setting
    network.get_input(0).shape = input_shape
    # Dynamic input setting: dynamic shapes are configured through the builder's
    # optimization profile; bind one profile to each dynamic input.
    if -1 in input_shape:
        profile = builder.create_optimization_profile()
        # min shape, typical shape, max shape; inference inputs must fall inside this range
        profile.set_shape(network.get_input(0).name, input_profile_shapes[0], input_profile_shapes[1], input_profile_shapes[2])
        config.add_optimization_profile(profile)
    print('Completed parsing the ONNX file')
    print(f'Building an engine from file {onnx_file_path}; this may take a while...')
    t0 = time.time()
    engine = builder.build_engine(network, config)
    with open(engine_file_path, 'wb') as f:
        # f.write(plan)
        f.write(engine.serialize())
    t1 = time.time()
    print('Completed creating Engine:%s, %.1f' % (engine_file_path, t1 - t0))
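# Example (sketch): building an FP16 engine with dynamic H/W from an existing ONNX file.
# The ONNX file name is an assumption; the profile range must cover every input shape
# you intend to use at inference time.
#
#   onnx2engine('u2net_dynamic_hw.onnx', 'u2net_dynamic_hw.engine',
#               input_shape=[1, 3, -1, -1],
#               half=True,
#               max_batch_size=1,
#               input_profile_shapes=[(1, 3, 512, 512), (1, 3, 1024, 1024), (1, 3, 2048, 2048)])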
try:
    # Sometimes python2 does not understand FileNotFoundError
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError
#EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
def GiB(val):
    return val * 1 << 30

def add_help(description):
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args, _ = parser.parse_known_args()
def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
    '''
    Parses sample arguments.

    Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.

    Returns:
        str: Path of data directory.
    '''
    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory, and any additional data directories.", action="append", default=[kDEFAULT_DATA_ROOT])
    args, _ = parser.parse_known_args()

    def get_data_path(data_dir):
        # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
        data_path = os.path.join(data_dir, subfolder)
        if not os.path.exists(data_path):
            print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
            data_path = data_dir
        # Make sure data directory exists.
        if not (os.path.exists(data_path)):
            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
        return data_path

    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
    return data_paths, locate_files(data_paths, find_files)
def locate_files(data_paths, filenames):
    """
    Locates the specified files in the specified data directories.
    If a file exists in multiple data directories, the first directory is used.

    Args:
        data_paths (List[str]): The data directories.
        filenames (List[str]): The names of the files to find.

    Returns:
        List[str]: The absolute paths of the files.

    Raises:
        FileNotFoundError if a file could not be located.
    """
    found_files = [None] * len(filenames)
    for data_path in data_paths:
        # Find all requested files.
        for index, (found, filename) in enumerate(zip(found_files, filenames)):
            if not found:
                file_path = os.path.abspath(os.path.join(data_path, filename))
                if os.path.exists(file_path):
                    found_files[index] = file_path
    # Check that all files were found.
    for f, filename in zip(found_files, filenames):
        if not f or not os.path.exists(f):
            raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}".format(filename, data_paths))
    return found_files
# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()
# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine, input_shape, streamFlag=True):
    inputs = []
    outputs = []
    bindings = []
    if streamFlag:
        stream = cuda.Stream()
    else:
        stream = None
    for ib, binding in enumerate(engine):
        dims = engine.get_binding_shape(binding)
        #print(engine.get_binding_name(ib), dims, engine.max_batch_size)
        if -1 in dims:
            if isinstance(input_shape, list):
                dims = input_shape[ib]
            else:
                dims = input_shape
        # size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        #size = trt.volume(dims) * engine.max_batch_size
        size = trt.volume(dims)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream
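# Example (sketch): deserializing a previously built engine and allocating its buffers.
# The engine path is an assumption, and the per-binding shapes below assume the usual
# U2NET layout of one (1,3,H,W) input followed by seven single-channel (1,1,H,W) outputs;
# adjust them to your engine's actual bindings.
#
#   with open('u2net_dynamic_hw.engine', 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
#       engine = runtime.deserialize_cuda_engine(f.read())
#   context = engine.create_execution_context()
#   shapes = [(1, 3, 512, 512)] + [(1, 1, 512, 512)] * 7   # one shape per binding
#   inputs, outputs, bindings, stream = allocate_buffers(engine, shapes)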
# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]
# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    #stream.synchronize()
    #context.execute_v2(bindings)  # synchronous execution
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]
def trt_inference(img, img_h, img_w, context, inputs, outputs, bindings, stream, input_name='input'):
    # Inputs:
    #   img -- numpy array in NCHW layout
    #   img_h, img_w -- H and W of the image fed to the model; required for dynamic-shape engines
    #   context -- the TensorRT execution context created by the caller
    #   inputs, outputs, bindings, stream -- buffers allocated when the first image is processed,
    #     with their device addresses bound to the engine
    #   input_name -- name of the model input tensor
    # Output:
    #   trt_outputs -- a list whose elements are numpy arrays
    origin_inputshape = context.get_tensor_shape(input_name)
    #if origin_inputshape[-1] == -1:
    context.set_optimization_profile_async(0, stream.handle)
    origin_inputshape[-2], origin_inputshape[-1] = (img_h, img_w)
    context.set_input_shape(input_name, origin_inputshape)
    inputs[0].host = np.ascontiguousarray(img)
    trt_outputs = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    return trt_outputs
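# Example (sketch): end-to-end TensorRT inference with trt_inference on a dynamic-H/W engine,
# reusing the context/inputs/outputs/bindings/stream from the allocate_buffers sketch above.
# The image path, the preprocessing, and the (1,1,512,512) output layout are assumptions.
#
#   img = cv2.imread('test.jpg')                                      # hypothetical image path
#   img = cv2.resize(img, (512, 512)).astype(np.float32) / 255.0
#   img = np.ascontiguousarray(img.transpose(2, 0, 1)[None])          # HWC -> NCHW
#   trt_outputs = trt_inference(img, 512, 512, context, inputs, outputs, bindings, stream)
#   pred = trt_outputs[0].reshape(1, 1, 512, 512)                     # assumed first saliency output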
# def do_inference_v3(context, bindings, inputs, outputs, stream, h_, w_):
#     '''
#     Copy from https://github.com/zhaogangthu/keras-yolo3-ocr-tensorrt/blob/master/tensorRT_yolo3/common.py
#     '''
#     # Transfer input data to the GPU.
#     context.set_binding_shape(0, (1, 3, h_, w_))
#     [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
#     # Run inference.
#     context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
#     # Transfer predictions back from the GPU.
#     [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
#     # Synchronize the stream
#     stream.synchronize()
#     # Return only the host outputs.
#     return [out.host for out in outputs]
if __name__ == '__main__':
    model_path = 'weights/u2net_portrait.pth'
    onnx_name = model_path.replace('.pth', '.onnx')
    trt_name = model_path.replace('.pth', '.engine')
    pth_model = U2NET(3, 1)
    pth_model.load_state_dict(torch.load(model_path))
    input_names = ['input']
    output_names = ['output%d' % (i) for i in range(7)]
    dynamix_axis = dynamic_hw
    input_shape = (1, 3, 512, 512)
    # Test: convert the pth model to ONNX
    #pth2onnx(pth_model, onnx_name, input_shape=input_shape, input_names=input_names, output_names=output_names, dynamix_axis=dynamix_axis)
    # Test: convert the ONNX model to a TensorRT engine
    input_profile_shapes = [(1, 3, 512, 512), (1, 3, 1024, 1024), (1, 3, 2048, 2048)]
    input_shape = [1, 3, -1, -1]
    half = True
    max_batch_size = 1
    onnx2engine(onnx_name, trt_name, input_shape=input_shape, half=half, max_batch_size=max_batch_size, input_profile_shapes=input_profile_shapes)

'''
with torch.no_grad():
    for i, model_name in enumerate(model_names):
        print(f'process model:{model_name}...')
        torch.onnx.export(model,
                          input_tensor,
                          model_name,
                          opset_version=11,
                          input_names=['input'],
                          output_names=['output0','output1','output2','output3','output4','output5','output6'],
                          dynamic_axes=dynamic_[i])
        print(f'onnx model:{model_name} saved successfully...')
        #print('sleep 10s...')
        time.sleep(10)
        print(f'begin check onnx model:{model_name}...')
        onnx_model = onnx.load(model_name)
        try:
            onnx.checker.check_model(onnx_model)
        except Exception as e:
            print('model incorrect')
            print(e)
        else:
            print('model correct')

print('*'*50)
print('Begin to test...')
case_1 = np.random.rand(1, 3, 512, 512).astype(np.float32)
case_2 = np.random.rand(2, 3, 512, 512).astype(np.float32)
case_3 = np.random.rand(1, 3, 224, 224).astype(np.float32)
cases = [case_1, case_2, case_3]
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
for model_name in model_names:
    print('-'*50, model_name)
    onnx_session = onnxruntime.InferenceSession(model_name, providers=providers)
    for i, case in enumerate(cases):
        onnx_input = {'input': case}
        try:
            onnx_output = onnx_session.run(['output0','output1','output2','output3','output4','output5','output6'], onnx_input)[0]
        except Exception as e:
            print(f'Input:{i} on model:{model_name} failed')
            print(e)
        else:
            print(f'Input:{i} on model:{model_name} succeed')
'''