64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
|
|
|
|
from pathlib import Path
|
|
import torch
|
|
import os,sys
|
|
import argparse
|
|
|
|
sys.path.extend(['segutils'])
|
|
from model_stages import BiSeNet_STDC
|
|
from trtUtils2 import pth2onnx,onnx2engine,onnx_inference
|
|
|
|
def main(opt):
    """Export a BiSeNet_STDC checkpoint to ONNX, then build a TensorRT engine.

    Args:
        opt: parsed command-line namespace. Attributes read here:
            weights (str): path to the ``.pth`` state_dict checkpoint;
                surrounding whitespace is stripped.
            nclass (int): number of segmentation classes.
            mWidth, mHeight (int): passed to ``BiSeNet_STDC`` as
                ``modelSize=(mHeight, mWidth)``; if either is 0,
                ``modelSize=None`` is passed instead.

    The ONNX name is ``<stem>_dynamic.onnx`` and the engine name is
    ``<stem>_dynamic.engine``, both in the checkpoint's directory. They are
    handed to the project helpers ``pth2onnx`` and ``onnx2engine``
    (presumably as their output paths -- confirm in trtUtils2).
    """
    pth_path = Path(opt.weights.strip())

    if opt.mWidth == 0 or opt.mHeight == 0:
        model_size = None
    else:
        model_size = (int(opt.mHeight), int(opt.mWidth))

    model = BiSeNet_STDC(
        backbone='STDCNet813',
        n_classes=int(opt.nclass),
        use_boundary_2=False,
        use_boundary_4=False,
        use_boundary_8=True,
        use_boundary_16=False,
        use_conv_last=False,
        modelSize=model_size,
    )
    model.load_state_dict(torch.load(str(pth_path), map_location='cuda:0'))
    # Export must trace inference behaviour (BatchNorm running stats, no
    # Dropout). pth2onnx may already switch modes itself -- can't tell from
    # here -- but calling eval() twice is harmless.
    model.eval()

    # Build output names from the stem, not with str.replace: a weights path
    # without '.pth' would otherwise yield names identical to the checkpoint
    # (and the export would overwrite it), and a '.pth' occurring in a
    # directory name would be rewritten too.
    onnx_path = pth_path.with_name(pth_path.stem + '_dynamic.onnx')
    trt_path = onnx_path.with_suffix('.engine')
    onnx_name = str(onnx_path)
    trt_name = str(trt_path)

    # Axes left symbolic in the exported graph: batch/H/W of the input and
    # C/H/W of both outputs.
    dynamic_hw = {
        'input': {0: 'batch', 2: 'H', 3: 'W'},
        'output0': {1: 'C', 2: 'H', 3: 'W'},
        'output1': {1: 'C', 2: 'H', 3: 'W'},
    }

    # (batch, channels, height, width) of the dummy input used for export.
    input_shape = (1, 3, 512, 512)

    # TensorRT optimisation-profile shapes; presumably (min, opt, max) in the
    # order onnx2engine expects -- confirm in trtUtils2.
    input_profile_shapes = [(1, 3, 256, 256), (1, 3, 1024, 1024), (1, 3, 2048, 2048)]

    # NOTE: 'dynamix_axis' is spelled as the trtUtils2 helper's keyword; do
    # not "correct" it here without changing the helper too.
    pth2onnx(model, onnx_name, input_shape=input_shape, input_names=['input'],
             output_names=['output0', 'output1'], dynamix_axis=dynamic_hw)

    onnx2engine(onnx_name, trt_name, input_shape=[1, 3, -1, -1], half=True,
                max_batch_size=1, input_profile_shapes=input_profile_shapes)
if __name__ == '__main__':
    # Command-line entry point: parse options and run the export pipeline.
    parser = argparse.ArgumentParser(
        description='Convert a BiSeNet_STDC .pth checkpoint to ONNX and a TensorRT engine.')
    parser.add_argument('--weights', type=str, default='stdc_360X640.pth',
                        help='path to the .pth state_dict checkpoint')
    parser.add_argument('--nclass', type=int, default=2,
                        help='segmodel number of classes')
    # main() passes modelSize=None when either dimension is 0.
    parser.add_argument('--mWidth', type=int, default=640,
                        help='segmodel input width (0 for width or height disables the fixed model size)')
    parser.add_argument('--mHeight', type=int, default=360,
                        help='segmodel input height (0 for width or height disables the fixed model size)')
    opt = parser.parse_args()

    main(opt)