基于Yolov7的路面病害检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

10 ay önce
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. #!/usr/bin/env python
  2. import argparse
  3. import numpy as np
  4. import sys
  5. import cv2
  6. import tritonclient.grpc as grpcclient
  7. from tritonclient.utils import InferenceServerException
  8. from processing import preprocess, postprocess
  9. from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
  10. from labels import COCOLabels
  11. INPUT_NAMES = ["images"]
  12. OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
  13. if __name__ == '__main__':
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument('mode',
  16. choices=['dummy', 'image', 'video'],
  17. default='dummy',
  18. help='Run mode. \'dummy\' will send an emtpy buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
  19. parser.add_argument('input',
  20. type=str,
  21. nargs='?',
  22. help='Input file to load from in image or video mode')
  23. parser.add_argument('-m',
  24. '--model',
  25. type=str,
  26. required=False,
  27. default='yolov7',
  28. help='Inference model name, default yolov7')
  29. parser.add_argument('--width',
  30. type=int,
  31. required=False,
  32. default=640,
  33. help='Inference model input width, default 640')
  34. parser.add_argument('--height',
  35. type=int,
  36. required=False,
  37. default=640,
  38. help='Inference model input height, default 640')
  39. parser.add_argument('-u',
  40. '--url',
  41. type=str,
  42. required=False,
  43. default='localhost:8001',
  44. help='Inference server URL, default localhost:8001')
  45. parser.add_argument('-o',
  46. '--out',
  47. type=str,
  48. required=False,
  49. default='',
  50. help='Write output into file instead of displaying it')
  51. parser.add_argument('-f',
  52. '--fps',
  53. type=float,
  54. required=False,
  55. default=24.0,
  56. help='Video output fps, default 24.0 FPS')
  57. parser.add_argument('-i',
  58. '--model-info',
  59. action="store_true",
  60. required=False,
  61. default=False,
  62. help='Print model status, configuration and statistics')
  63. parser.add_argument('-v',
  64. '--verbose',
  65. action="store_true",
  66. required=False,
  67. default=False,
  68. help='Enable verbose client output')
  69. parser.add_argument('-t',
  70. '--client-timeout',
  71. type=float,
  72. required=False,
  73. default=None,
  74. help='Client timeout in seconds, default no timeout')
  75. parser.add_argument('-s',
  76. '--ssl',
  77. action="store_true",
  78. required=False,
  79. default=False,
  80. help='Enable SSL encrypted channel to the server')
  81. parser.add_argument('-r',
  82. '--root-certificates',
  83. type=str,
  84. required=False,
  85. default=None,
  86. help='File holding PEM-encoded root certificates, default none')
  87. parser.add_argument('-p',
  88. '--private-key',
  89. type=str,
  90. required=False,
  91. default=None,
  92. help='File holding PEM-encoded private key, default is none')
  93. parser.add_argument('-x',
  94. '--certificate-chain',
  95. type=str,
  96. required=False,
  97. default=None,
  98. help='File holding PEM-encoded certicate chain default is none')
  99. FLAGS = parser.parse_args()
  100. # Create server context
  101. try:
  102. triton_client = grpcclient.InferenceServerClient(
  103. url=FLAGS.url,
  104. verbose=FLAGS.verbose,
  105. ssl=FLAGS.ssl,
  106. root_certificates=FLAGS.root_certificates,
  107. private_key=FLAGS.private_key,
  108. certificate_chain=FLAGS.certificate_chain)
  109. except Exception as e:
  110. print("context creation failed: " + str(e))
  111. sys.exit()
  112. # Health check
  113. if not triton_client.is_server_live():
  114. print("FAILED : is_server_live")
  115. sys.exit(1)
  116. if not triton_client.is_server_ready():
  117. print("FAILED : is_server_ready")
  118. sys.exit(1)
  119. if not triton_client.is_model_ready(FLAGS.model):
  120. print("FAILED : is_model_ready")
  121. sys.exit(1)
  122. if FLAGS.model_info:
  123. # Model metadata
  124. try:
  125. metadata = triton_client.get_model_metadata(FLAGS.model)
  126. print(metadata)
  127. except InferenceServerException as ex:
  128. if "Request for unknown model" not in ex.message():
  129. print("FAILED : get_model_metadata")
  130. print("Got: {}".format(ex.message()))
  131. sys.exit(1)
  132. else:
  133. print("FAILED : get_model_metadata")
  134. sys.exit(1)
  135. # Model configuration
  136. try:
  137. config = triton_client.get_model_config(FLAGS.model)
  138. if not (config.config.name == FLAGS.model):
  139. print("FAILED: get_model_config")
  140. sys.exit(1)
  141. print(config)
  142. except InferenceServerException as ex:
  143. print("FAILED : get_model_config")
  144. print("Got: {}".format(ex.message()))
  145. sys.exit(1)
  146. # DUMMY MODE
  147. if FLAGS.mode == 'dummy':
  148. print("Running in 'dummy' mode")
  149. print("Creating emtpy buffer filled with ones...")
  150. inputs = []
  151. outputs = []
  152. inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
  153. inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
  154. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
  155. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
  156. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
  157. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
  158. print("Invoking inference...")
  159. results = triton_client.infer(model_name=FLAGS.model,
  160. inputs=inputs,
  161. outputs=outputs,
  162. client_timeout=FLAGS.client_timeout)
  163. if FLAGS.model_info:
  164. statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
  165. if len(statistics.model_stats) != 1:
  166. print("FAILED: get_inference_statistics")
  167. sys.exit(1)
  168. print(statistics)
  169. print("Done")
  170. for output in OUTPUT_NAMES:
  171. result = results.as_numpy(output)
  172. print(f"Received result buffer \"{output}\" of size {result.shape}")
  173. print(f"Naive buffer sum: {np.sum(result)}")
  174. # IMAGE MODE
  175. if FLAGS.mode == 'image':
  176. print("Running in 'image' mode")
  177. if not FLAGS.input:
  178. print("FAILED: no input image")
  179. sys.exit(1)
  180. inputs = []
  181. outputs = []
  182. inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
  183. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
  184. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
  185. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
  186. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
  187. print("Creating buffer from image file...")
  188. input_image = cv2.imread(str(FLAGS.input))
  189. if input_image is None:
  190. print(f"FAILED: could not load input image {str(FLAGS.input)}")
  191. sys.exit(1)
  192. input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
  193. input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
  194. inputs[0].set_data_from_numpy(input_image_buffer)
  195. print("Invoking inference...")
  196. results = triton_client.infer(model_name=FLAGS.model,
  197. inputs=inputs,
  198. outputs=outputs,
  199. client_timeout=FLAGS.client_timeout)
  200. if FLAGS.model_info:
  201. statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
  202. if len(statistics.model_stats) != 1:
  203. print("FAILED: get_inference_statistics")
  204. sys.exit(1)
  205. print(statistics)
  206. print("Done")
  207. for output in OUTPUT_NAMES:
  208. result = results.as_numpy(output)
  209. print(f"Received result buffer \"{output}\" of size {result.shape}")
  210. print(f"Naive buffer sum: {np.sum(result)}")
  211. num_dets = results.as_numpy(OUTPUT_NAMES[0])
  212. det_boxes = results.as_numpy(OUTPUT_NAMES[1])
  213. det_scores = results.as_numpy(OUTPUT_NAMES[2])
  214. det_classes = results.as_numpy(OUTPUT_NAMES[3])
  215. detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
  216. print(f"Detected objects: {len(detected_objects)}")
  217. for box in detected_objects:
  218. print(f"{COCOLabels(box.classID).name}: {box.confidence}")
  219. input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
  220. size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
  221. input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
  222. input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
  223. if FLAGS.out:
  224. cv2.imwrite(FLAGS.out, input_image)
  225. print(f"Saved result to {FLAGS.out}")
  226. else:
  227. cv2.imshow('image', input_image)
  228. cv2.waitKey(0)
  229. cv2.destroyAllWindows()
  230. # VIDEO MODE
  231. if FLAGS.mode == 'video':
  232. print("Running in 'video' mode")
  233. if not FLAGS.input:
  234. print("FAILED: no input video")
  235. sys.exit(1)
  236. inputs = []
  237. outputs = []
  238. inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
  239. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
  240. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
  241. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
  242. outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
  243. print("Opening input video stream...")
  244. cap = cv2.VideoCapture(FLAGS.input)
  245. if not cap.isOpened():
  246. print(f"FAILED: cannot open video {FLAGS.input}")
  247. sys.exit(1)
  248. counter = 0
  249. out = None
  250. print("Invoking inference...")
  251. while True:
  252. ret, frame = cap.read()
  253. if not ret:
  254. print("failed to fetch next frame")
  255. break
  256. if counter == 0 and FLAGS.out:
  257. print("Opening output video stream...")
  258. fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
  259. out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))
  260. input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
  261. input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
  262. inputs[0].set_data_from_numpy(input_image_buffer)
  263. results = triton_client.infer(model_name=FLAGS.model,
  264. inputs=inputs,
  265. outputs=outputs,
  266. client_timeout=FLAGS.client_timeout)
  267. num_dets = results.as_numpy("num_dets")
  268. det_boxes = results.as_numpy("det_boxes")
  269. det_scores = results.as_numpy("det_scores")
  270. det_classes = results.as_numpy("det_classes")
  271. detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
  272. print(f"Frame {counter}: {len(detected_objects)} objects")
  273. counter += 1
  274. for box in detected_objects:
  275. print(f"{COCOLabels(box.classID).name}: {box.confidence}")
  276. frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
  277. size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
  278. frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
  279. frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
  280. if FLAGS.out:
  281. out.write(frame)
  282. else:
  283. cv2.imshow('image', frame)
  284. if cv2.waitKey(1) == ord('q'):
  285. break
  286. if FLAGS.model_info:
  287. statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
  288. if len(statistics.model_stats) != 1:
  289. print("FAILED: get_inference_statistics")
  290. sys.exit(1)
  291. print(statistics)
  292. print("Done")
  293. cap.release()
  294. if FLAGS.out:
  295. out.release()
  296. else:
  297. cv2.destroyAllWindows()