TensorRT转化代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

60 lines
2.2KB

  1. import sys
  2. import argparse
  3. import os
  4. import struct
  5. import torch
  6. from utils.torch_utils import select_device
  7. def parse_args():
  8. parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
  9. parser.add_argument('-w', '--weights', required=True,
  10. help='Input weights (.pt) file path (required)')
  11. parser.add_argument(
  12. '-o', '--output', help='Output (.wts) file path (optional)')
  13. parser.add_argument(
  14. '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg'],
  15. help='determines the model is detection/classification')
  16. args = parser.parse_args()
  17. if not os.path.isfile(args.weights):
  18. raise SystemExit('Invalid input file')
  19. if not args.output:
  20. args.output = os.path.splitext(args.weights)[0] + '.wts'
  21. elif os.path.isdir(args.output):
  22. args.output = os.path.join(
  23. args.output,
  24. os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
  25. return args.weights, args.output, args.type
  26. pt_file, wts_file, m_type = parse_args()
  27. print(f'Generating .wts for {m_type} model')
  28. # Load model
  29. print(f'Loading {pt_file}')
  30. device = select_device('cpu')
  31. model = torch.load(pt_file, map_location=device) # Load FP32 weights
  32. model = model['ema' if model.get('ema') else 'model'].float()
  33. if m_type in ['detect', 'seg']:
  34. # update anchor_grid info
  35. anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
  36. # model.model[-1].anchor_grid = anchor_grid
  37. delattr(model.model[-1], 'anchor_grid') # model.model[-1] is detect layer
  38. # The parameters are saved in the OrderDict through the "register_buffer" method, and then saved to the weight.
  39. model.model[-1].register_buffer("anchor_grid", anchor_grid)
  40. model.model[-1].register_buffer("strides", model.model[-1].stride)
  41. model.to(device).eval()
  42. print(f'Writing into {wts_file}')
  43. with open(wts_file, 'w') as f:
  44. f.write('{}\n'.format(len(model.state_dict().keys())))
  45. for k, v in model.state_dict().items():
  46. vr = v.reshape(-1).cpu().numpy()
  47. f.write('{} {} '.format(k, len(vr)))
  48. for vv in vr:
  49. f.write(' ')
  50. f.write(struct.pack('>f', float(vv)).hex())
  51. f.write('\n')