#
# Copyright 1993-2020 NVIDIA Corporation.  All rights reserved.
#
# NOTICE TO LICENSEE:
#
# This source code and/or documentation ("Licensed Deliverables") are
# subject to NVIDIA intellectual property rights under U.S. and
# international Copyright laws.
#
# These Licensed Deliverables contained herein is PROPRIETARY and
# CONFIDENTIAL to NVIDIA and is being provided under the terms and
# conditions of a form of NVIDIA software license agreement by and
# between NVIDIA and Licensee ("License Agreement") or electronically
# accepted by Licensee.  Notwithstanding any terms or conditions to
# the contrary in the License Agreement, reproduction or disclosure
# of the Licensed Deliverables to any third party without the express
# written consent of NVIDIA is prohibited.
#
# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
# OF THESE LICENSED DELIVERABLES.
#
# U.S. Government End Users.  These Licensed Deliverables are a
# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
# 1995), consisting of "commercial computer software" and "commercial
# computer software documentation" as such terms are used in 48
# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
# only as a commercial end item.  Consistent with 48 C.F.R. 12.212 and
# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
# U.S. Government End Users acquire the Licensed Deliverables with
# only those rights set forth herein.
#
# Any use of the Licensed Deliverables in individual and commercial
# software must include, in the user documentation and internal
# comments to the code, the above Disclaimer and U.S. Government End
# Users Notice.
#
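"""
Export a U^2-Net (u2net) PyTorch model to ONNX and build a TensorRT engine from
it, then run engine inference with PyCUDA (page-locked buffer allocation and
host/device transfers). The generic helpers (GiB, find_sample_data, locate_files,
HostDeviceMem, allocate_buffers, do_inference*) follow the NVIDIA TensorRT
Python sample helpers (common.py).
"""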
import argparse
import os
import sys
import time

import cv2
import numpy as np
import onnx
import onnxruntime
import pycuda.driver as cuda
import pycuda.autoinit  # initializes CUDA and creates a default context on import
import tensorrt as trt
import torch

from model.u2net import U2NET  # needed by the __main__ demo below

# cuda.init()

model_names = ['u2net.onnx',
               'u2net_dynamic_batch.onnx',
               'u2net_dynamic_hw.onnx',
               'u2net_dynamic_batch-hw.onnx']

dynamic_batch = {'input': {0: 'batch'},
                 'output0': {0: 'batch'},
                 'output1': {0: 'batch'},
                 'output2': {0: 'batch'},
                 'output3': {0: 'batch'},
                 'output4': {0: 'batch'},
                 'output5': {0: 'batch'},
                 'output6': {0: 'batch'}}

dynamic_hw = {'input': {2: 'H', 3: 'W'},
              'output0': {2: 'H', 3: 'W'},
              'output1': {2: 'H', 3: 'W'},
              'output2': {2: 'H', 3: 'W'},
              'output3': {2: 'H', 3: 'W'},
              'output4': {2: 'H', 3: 'W'},
              'output5': {2: 'H', 3: 'W'},
              'output6': {2: 'H', 3: 'W'}}

dynamic_batch_hw = {'input': {0: 'batch', 2: 'H', 3: 'W'},
                    'output0': {0: 'batch', 2: 'H', 3: 'W'},
                    'output1': {0: 'batch', 2: 'H', 3: 'W'},
                    'output2': {0: 'batch', 2: 'H', 3: 'W'},
                    'output3': {0: 'batch', 2: 'H', 3: 'W'},
                    'output4': {0: 'batch', 2: 'H', 3: 'W'},
                    'output5': {0: 'batch', 2: 'H', 3: 'W'},
                    'output6': {0: 'batch', 2: 'H', 3: 'W'}}

dynamic_ = [None, dynamic_batch, dynamic_hw, dynamic_batch_hw]

TRT_LOGGER = trt.Logger()


def pth2onnx(pth_model, onnx_name, input_shape=(1, 3, 512, 512), input_names=['input'],
             output_names=['output'], dynamix_axis=None):
    # pth_model: PyTorch model with its weights already loaded
    # onnx_name: output path of the ONNX model
    # input_shape: model input size (recommended size)
    # input_names: names of the model inputs, as a list; multiple inputs are allowed
    # output_names: names of the model outputs, as a list; multiple outputs are allowed
    # dynamix_axis: dict; None means a static input. Every model input/output may declare dynamic dimensions,
    #     e.g. dynamic_batch_hw = {'input': {0: 'batch', 2: 'H', 3: 'W'}, 'output': {0: 'batch', 2: 'H', 3: 'W'}}
    #     means B, H, W of 'input' and B, H, W of 'output' are dynamic.
    print('[I] begin converting pth to onnx ...', dynamix_axis)
    input_tensor = torch.ones(input_shape)
    if next(pth_model.parameters()).is_cuda:
        input_tensor = input_tensor.to('cuda:0')
    with torch.no_grad():
        torch.onnx.export(pth_model,
                          input_tensor,
                          onnx_name,
                          opset_version=11,
                          input_names=input_names,
                          do_constant_folding=True,
                          output_names=output_names,
                          dynamic_axes=dynamix_axis)
    onnx_model = onnx.load(onnx_name)
    try:
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        print('[Error] model incorrect:', e)
    else:
        print('[I] conversion to onnx finished:', onnx_name)
    print('')


def onnx_inference(onnx_input, model_name,
                   outputName=['output0', 'output1', 'output2', 'output3', 'output4', 'output5', 'output6']):
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    print('[I] onnx inference with model:', model_name)
    # outputName = ['pred_logits', 'pred_points']
    onnx_session = onnxruntime.InferenceSession(model_name, providers=providers)
    try:
        onnx_output = onnx_session.run(outputName, onnx_input)
    except Exception as e:
        onnx_output = None
        print(e)
    return onnx_output


def onnx2engine(onnx_file_path, engine_file_path, input_shape=[1, 3, 512, 512], half=True,
                max_batch_size=1, input_profile_shapes=[None, None, None]):
    # onnx_file_path: path of the input ONNX model
    # engine_file_path: output path of the TensorRT engine
    # input_shape: default model input shape, e.g. [1,3,512,512]; use [1,3,-1,-1] for dynamic shapes
    # half: whether to build with FP16, default True
    # max_batch_size: maximum batch size, default 1
    # input_profile_shapes: for dynamic inputs, the three shapes [min, opt, max]; input_shape must then contain -1,
    #     e.g. (1,3,512,512), (1,3,1024,1024), (1,3,2048,2048)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    runtime = trt.Runtime(TRT_LOGGER)
    # Maximum workspace memory TensorRT may use for tactic selection (commonly 1 GiB);
    # lower it if the GPU runs out of memory during the build.
    config.max_workspace_size = 1 << 30  # 1 GiB
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    builder.max_batch_size = max_batch_size  # at inference time, batch_size must be <= max_batch_size
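    # Note (assumption about newer TensorRT releases): max_workspace_size and
    # builder.max_batch_size are deprecated in TensorRT 8.x; on those versions the
    # equivalent workspace setting would be:
    #     config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)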
    # parse model file
    if not os.path.exists(onnx_file_path):
        print(f'onnx file {onnx_file_path} not found, please run torch_2_onnx.py first to generate it')
        exit(0)
    print(f'Loading ONNX file from path {onnx_file_path}...')
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # Static input setting
    network.get_input(0).shape = input_shape
    # Dynamic input setting: dynamic shapes are configured through an optimization profile on the builder;
    # bind one profile to each dynamic input.
    if -1 in input_shape:
        profile = builder.create_optimization_profile()
        # min / opt / max shapes; at inference time the input shape must lie within this range
        profile.set_shape(network.get_input(0).name,
                          input_profile_shapes[0], input_profile_shapes[1], input_profile_shapes[2])
        config.add_optimization_profile(profile)
    print('Completed parsing the ONNX file')
    print(f'Building an engine from file {onnx_file_path}; this may take a while...')
    t0 = time.time()
    engine = builder.build_engine(network, config)
    with open(engine_file_path, 'wb') as f:
        f.write(engine.serialize())
    t1 = time.time()
    print('Completed creating Engine: %s, %.1f s' % (engine_file_path, t1 - t0))


try:
    # Sometimes python2 does not understand FileNotFoundError
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError

# EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


def GiB(val):
    return val * 1 << 30


def add_help(description):
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args, _ = parser.parse_known_args()


def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
    '''
    Parses sample arguments.

    Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.

    Returns:
        str: Path of data directory.
    '''
    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-d", "--datadir",
                        help="Location of the TensorRT sample data directory, and any additional data directories.",
                        action="append", default=[kDEFAULT_DATA_ROOT])
    args, _ = parser.parse_known_args()

    def get_data_path(data_dir):
        # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
        data_path = os.path.join(data_dir, subfolder)
        if not os.path.exists(data_path):
            print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
            data_path = data_dir
        # Make sure data directory exists.
        if not (os.path.exists(data_path)):
            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
        return data_path

    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
    return data_paths, locate_files(data_paths, find_files)


def locate_files(data_paths, filenames):
    """
    Locates the specified files in the specified data directories.
    If a file exists in multiple data directories, the first directory is used.

    Args:
        data_paths (List[str]): The data directories.
        filename (List[str]): The names of the files to find.

    Returns:
        List[str]: The absolute paths of the files.

    Raises:
        FileNotFoundError if a file could not be located.
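    Example (illustrative; the directory and file name are hypothetical):
        locate_files(["/usr/src/tensorrt/data"], ["sample.ppm"])
        # returns ["/usr/src/tensorrt/data/sample.ppm"] if that file exists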
""" found_files = [None] * len(filenames) for data_path in data_paths: # Find all requested files. for index, (found, filename) in enumerate(zip(found_files, filenames)): if not found: file_path = os.path.abspath(os.path.join(data_path, filename)) if os.path.exists(file_path): found_files[index] = file_path # Check that all files were found for f, filename in zip(found_files, filenames): if not f or not os.path.exists(f): raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}".format(filename, data_paths)) return found_files # Simple helper data class that's a little nicer to use than a 2-tuple. class HostDeviceMem(object): def __init__(self, host_mem, device_mem): self.host = host_mem self.device = device_mem def __str__(self): return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) def __repr__(self): return self.__str__() # Allocates all buffers required for an engine, i.e. host/device inputs/outputs. def allocate_buffers(engine,input_shape,streamFlag=True): inputs = [] outputs = [] bindings = [] if streamFlag: stream = cuda.Stream() else: stream=None for ib,binding in enumerate(engine): dims = engine.get_binding_shape(binding) #print(engine.get_binding_name(ib),dims,engine.max_batch_size) if -1 in dims: if isinstance(input_shape,list): dims = input_shape[ib] else: dims = input_shape # size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size #size = trt.volume(dims) * engine.max_batch_size size = trt.volume(dims) dtype = trt.nptype(engine.get_binding_dtype(binding)) # Allocate host and device buffers host_mem = cuda.pagelocked_empty(size, dtype) device_mem = cuda.mem_alloc(host_mem.nbytes) # Append the device buffer to device bindings. bindings.append(int(device_mem)) # Append to the appropriate list. if engine.binding_is_input(binding): inputs.append(HostDeviceMem(host_mem, device_mem)) else: outputs.append(HostDeviceMem(host_mem, device_mem)) return inputs, outputs, bindings, stream # This function is generalized for multiple inputs/outputs. # inputs and outputs are expected to be lists of HostDeviceMem objects. def do_inference(context, bindings, inputs, outputs, stream, batch_size=1): # Transfer input data to the GPU. [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] # Run inference. context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) # Transfer predictions back from the GPU. [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] # Synchronize the stream stream.synchronize() # Return only the host outputs. return [out.host for out in outputs] # This function is generalized for multiple inputs/outputs for full dimension networks. # inputs and outputs are expected to be lists of HostDeviceMem objects. def do_inference_v2(context, bindings, inputs, outputs, stream): # Transfer input data to the GPU. [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] # Run inference. #stream.synchronize() #context.execute_v2(bindings) # 执行推 context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) # Transfer predictions back from the GPU. [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] # Synchronize the stream stream.synchronize() # Return only the host outputs. 
    return [out.host for out in outputs]


def trt_inference(img, img_h, img_w, context, inputs, outputs, bindings, stream, input_name='input'):
    # Inputs:
    #   img -- numpy array in NCHW layout
    #   img_h, img_w -- H and W of the image fed to the model (needed for dynamic inputs)
    #   context -- TensorRT execution context created by the caller
    #   inputs, outputs, bindings, stream -- buffers allocated when the first image is processed,
    #       with their addresses bound to the engine bindings
    #   input_name -- name of the model input tensor
    # Output:
    #   trt_outputs -- a list whose elements are numpy arrays
    origin_inputshape = context.get_tensor_shape(input_name)
    # if origin_inputshape[-1] == -1:
    context.set_optimization_profile_async(0, stream.handle)
    origin_inputshape[-2], origin_inputshape[-1] = (img_h, img_w)
    context.set_input_shape(input_name, origin_inputshape)
    inputs[0].host = np.ascontiguousarray(img)
    trt_outputs = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    return trt_outputs


# def do_inference_v3(context, bindings, inputs, outputs, stream, h_, w_):
#     '''
#     Copied from https://github.com/zhaogangthu/keras-yolo3-ocr-tensorrt/blob/master/tensorRT_yolo3/common.py
#     '''
#     # Transfer input data to the GPU.
#     # context.set_binding_shape(0, (1, 3, h_, w_))
#     # [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
#     # Run inference.
#     context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
#     # Transfer predictions back from the GPU.
#     [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
#     # Synchronize the stream
#     stream.synchronize()
#     # Return only the host outputs.
#     return [out.host for out in outputs]


if __name__ == '__main__':
    model_path = 'weights/u2net_portrait.pth'
    onnx_name = model_path.replace('.pth', '.onnx')
    trt_name = model_path.replace('.pth', '.engine')

    pth_model = U2NET(3, 1)
    pth_model.load_state_dict(torch.load(model_path))

    input_names = ['input']
    output_names = ['output%d' % (i) for i in range(7)]
    dynamix_axis = dynamic_hw
    input_shape = (1, 3, 512, 512)

    # Test converting the pth model to ONNX
    # pth2onnx(pth_model, onnx_name, input_shape=input_shape, input_names=input_names,
    #          output_names=output_names, dynamix_axis=dynamix_axis)

    # Test converting the ONNX model to a TensorRT engine
    input_profile_shapes = [(1, 3, 512, 512), (1, 3, 1024, 1024), (1, 3, 2048, 2048)]
    input_shape = [1, 3, -1, -1]
    half = True
    max_batch_size = 1
    onnx2engine(onnx_name, trt_name, input_shape=input_shape, half=half,
                max_batch_size=max_batch_size, input_profile_shapes=input_profile_shapes)

    '''
    with torch.no_grad():
        for i, model_name in enumerate(model_names):
            print(f'process model:{model_name}...')
            torch.onnx.export(model,
                              input_tensor,
                              model_name,
                              opset_version=11,
                              input_names=['input'],
                              output_names=['output0','output1','output2','output3','output4','output5','output6'],
                              dynamic_axes=dynamic_[i])
            print(f'onnx model:{model_name} saved successfully...')
            # print('sleep 10s...')
            time.sleep(10)
            print(f'begin check onnx model:{model_name}...')
            onnx_model = onnx.load(model_name)
            try:
                onnx.checker.check_model(onnx_model)
            except Exception as e:
                print('model incorrect')
                print(e)
            else:
                print('model correct')

    print('*'*50)
    print('Begin to test...')
    case_1 = np.random.rand(1,3,512,512).astype(np.float32)
    case_2 = np.random.rand(2,3,512,512).astype(np.float32)
    case_3 = np.random.rand(1,3,224,224).astype(np.float32)
    cases = [case_1, case_2, case_3]
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    for model_name in model_names:
        print('-'*50, model_name)
        onnx_session = onnxruntime.InferenceSession(model_name, providers=providers)
        for i, case in enumerate(cases):
            onnx_input = {'input': case}
            try:
                onnx_output = onnx_session.run(['output0','output1','output2','output3','output4','output5','output6'], onnx_input)[0]
            except Exception as e:
                print(f'Input:{i} on model:{model_name} failed')
                print(e)
            else:
                print(f'Input:{i} on model:{model_name} succeed')
    '''
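

# ----------------------------------------------------------------------------
# Minimal usage sketch (defined but never called): how the helpers above fit
# together to run a serialized engine. Assumptions: the engine was built by
# onnx2engine() with the dynamic-HW profile used in __main__, the engine path
# below is only illustrative, and each of the seven U^2-Net outputs has shape
# (1, 1, H, W).
def example_engine_inference(engine_path='weights/u2net_portrait.engine', img_h=512, img_w=512):
    # Dummy NCHW input standing in for a real preprocessed image.
    img = np.random.rand(1, 3, img_h, img_w).astype(np.float32)
    # Deserialize the engine and create an execution context.
    with open(engine_path, 'rb') as f:
        engine = trt.Runtime(TRT_LOGGER).deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()
    # One concrete shape per binding (1 input + 7 outputs) so allocate_buffers()
    # can size the dynamic (-1) bindings.
    binding_shapes = [(1, 3, img_h, img_w)] + [(1, 1, img_h, img_w)] * 7
    inputs, outputs, bindings, stream = allocate_buffers(engine, binding_shapes)
    # Set the runtime input shape on the context and run one inference.
    trt_outputs = trt_inference(img, img_h, img_w, context, inputs, outputs, bindings, stream)
    return trt_outputs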