|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- #!/usr/bin/env python
-
- import argparse
- import numpy as np
- import sys
- import cv2
-
- import tritonclient.grpc as grpcclient
- from tritonclient.utils import InferenceServerException
-
- from processing import preprocess, postprocess
- from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
- from labels import COCOLabels
-
- INPUT_NAMES = ["images"]
- OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
-
- if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('mode',
- choices=['dummy', 'image', 'video'],
- default='dummy',
- help='Run mode. \'dummy\' will send an emtpy buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
- parser.add_argument('input',
- type=str,
- nargs='?',
- help='Input file to load from in image or video mode')
- parser.add_argument('-m',
- '--model',
- type=str,
- required=False,
- default='yolov7',
- help='Inference model name, default yolov7')
- parser.add_argument('--width',
- type=int,
- required=False,
- default=640,
- help='Inference model input width, default 640')
- parser.add_argument('--height',
- type=int,
- required=False,
- default=640,
- help='Inference model input height, default 640')
- parser.add_argument('-u',
- '--url',
- type=str,
- required=False,
- default='localhost:8001',
- help='Inference server URL, default localhost:8001')
- parser.add_argument('-o',
- '--out',
- type=str,
- required=False,
- default='',
- help='Write output into file instead of displaying it')
- parser.add_argument('-f',
- '--fps',
- type=float,
- required=False,
- default=24.0,
- help='Video output fps, default 24.0 FPS')
- parser.add_argument('-i',
- '--model-info',
- action="store_true",
- required=False,
- default=False,
- help='Print model status, configuration and statistics')
- parser.add_argument('-v',
- '--verbose',
- action="store_true",
- required=False,
- default=False,
- help='Enable verbose client output')
- parser.add_argument('-t',
- '--client-timeout',
- type=float,
- required=False,
- default=None,
- help='Client timeout in seconds, default no timeout')
- parser.add_argument('-s',
- '--ssl',
- action="store_true",
- required=False,
- default=False,
- help='Enable SSL encrypted channel to the server')
- parser.add_argument('-r',
- '--root-certificates',
- type=str,
- required=False,
- default=None,
- help='File holding PEM-encoded root certificates, default none')
- parser.add_argument('-p',
- '--private-key',
- type=str,
- required=False,
- default=None,
- help='File holding PEM-encoded private key, default is none')
- parser.add_argument('-x',
- '--certificate-chain',
- type=str,
- required=False,
- default=None,
- help='File holding PEM-encoded certicate chain default is none')
-
- FLAGS = parser.parse_args()
-
- # Create server context
- try:
- triton_client = grpcclient.InferenceServerClient(
- url=FLAGS.url,
- verbose=FLAGS.verbose,
- ssl=FLAGS.ssl,
- root_certificates=FLAGS.root_certificates,
- private_key=FLAGS.private_key,
- certificate_chain=FLAGS.certificate_chain)
- except Exception as e:
- print("context creation failed: " + str(e))
- sys.exit()
-
- # Health check
- if not triton_client.is_server_live():
- print("FAILED : is_server_live")
- sys.exit(1)
-
- if not triton_client.is_server_ready():
- print("FAILED : is_server_ready")
- sys.exit(1)
-
- if not triton_client.is_model_ready(FLAGS.model):
- print("FAILED : is_model_ready")
- sys.exit(1)
-
- if FLAGS.model_info:
- # Model metadata
- try:
- metadata = triton_client.get_model_metadata(FLAGS.model)
- print(metadata)
- except InferenceServerException as ex:
- if "Request for unknown model" not in ex.message():
- print("FAILED : get_model_metadata")
- print("Got: {}".format(ex.message()))
- sys.exit(1)
- else:
- print("FAILED : get_model_metadata")
- sys.exit(1)
-
- # Model configuration
- try:
- config = triton_client.get_model_config(FLAGS.model)
- if not (config.config.name == FLAGS.model):
- print("FAILED: get_model_config")
- sys.exit(1)
- print(config)
- except InferenceServerException as ex:
- print("FAILED : get_model_config")
- print("Got: {}".format(ex.message()))
- sys.exit(1)
-
- # DUMMY MODE
- if FLAGS.mode == 'dummy':
- print("Running in 'dummy' mode")
- print("Creating emtpy buffer filled with ones...")
- inputs = []
- outputs = []
- inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
- inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
-
- print("Invoking inference...")
- results = triton_client.infer(model_name=FLAGS.model,
- inputs=inputs,
- outputs=outputs,
- client_timeout=FLAGS.client_timeout)
- if FLAGS.model_info:
- statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
- if len(statistics.model_stats) != 1:
- print("FAILED: get_inference_statistics")
- sys.exit(1)
- print(statistics)
- print("Done")
-
- for output in OUTPUT_NAMES:
- result = results.as_numpy(output)
- print(f"Received result buffer \"{output}\" of size {result.shape}")
- print(f"Naive buffer sum: {np.sum(result)}")
-
- # IMAGE MODE
- if FLAGS.mode == 'image':
- print("Running in 'image' mode")
- if not FLAGS.input:
- print("FAILED: no input image")
- sys.exit(1)
-
- inputs = []
- outputs = []
- inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
-
- print("Creating buffer from image file...")
- input_image = cv2.imread(str(FLAGS.input))
- if input_image is None:
- print(f"FAILED: could not load input image {str(FLAGS.input)}")
- sys.exit(1)
- input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
- input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
-
- inputs[0].set_data_from_numpy(input_image_buffer)
-
- print("Invoking inference...")
- results = triton_client.infer(model_name=FLAGS.model,
- inputs=inputs,
- outputs=outputs,
- client_timeout=FLAGS.client_timeout)
- if FLAGS.model_info:
- statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
- if len(statistics.model_stats) != 1:
- print("FAILED: get_inference_statistics")
- sys.exit(1)
- print(statistics)
- print("Done")
-
- for output in OUTPUT_NAMES:
- result = results.as_numpy(output)
- print(f"Received result buffer \"{output}\" of size {result.shape}")
- print(f"Naive buffer sum: {np.sum(result)}")
-
- num_dets = results.as_numpy(OUTPUT_NAMES[0])
- det_boxes = results.as_numpy(OUTPUT_NAMES[1])
- det_scores = results.as_numpy(OUTPUT_NAMES[2])
- det_classes = results.as_numpy(OUTPUT_NAMES[3])
- detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
- print(f"Detected objects: {len(detected_objects)}")
-
- for box in detected_objects:
- print(f"{COCOLabels(box.classID).name}: {box.confidence}")
- input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
- size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
- input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
- input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
-
- if FLAGS.out:
- cv2.imwrite(FLAGS.out, input_image)
- print(f"Saved result to {FLAGS.out}")
- else:
- cv2.imshow('image', input_image)
- cv2.waitKey(0)
- cv2.destroyAllWindows()
-
- # VIDEO MODE
- if FLAGS.mode == 'video':
- print("Running in 'video' mode")
- if not FLAGS.input:
- print("FAILED: no input video")
- sys.exit(1)
-
- inputs = []
- outputs = []
- inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
- outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
-
- print("Opening input video stream...")
- cap = cv2.VideoCapture(FLAGS.input)
- if not cap.isOpened():
- print(f"FAILED: cannot open video {FLAGS.input}")
- sys.exit(1)
-
- counter = 0
- out = None
- print("Invoking inference...")
- while True:
- ret, frame = cap.read()
- if not ret:
- print("failed to fetch next frame")
- break
-
- if counter == 0 and FLAGS.out:
- print("Opening output video stream...")
- fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
- out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))
-
- input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
- input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
-
- inputs[0].set_data_from_numpy(input_image_buffer)
-
- results = triton_client.infer(model_name=FLAGS.model,
- inputs=inputs,
- outputs=outputs,
- client_timeout=FLAGS.client_timeout)
-
- num_dets = results.as_numpy("num_dets")
- det_boxes = results.as_numpy("det_boxes")
- det_scores = results.as_numpy("det_scores")
- det_classes = results.as_numpy("det_classes")
- detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
- print(f"Frame {counter}: {len(detected_objects)} objects")
- counter += 1
-
- for box in detected_objects:
- print(f"{COCOLabels(box.classID).name}: {box.confidence}")
- frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
- size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
- frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
- frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
-
- if FLAGS.out:
- out.write(frame)
- else:
- cv2.imshow('image', frame)
- if cv2.waitKey(1) == ord('q'):
- break
-
- if FLAGS.model_info:
- statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
- if len(statistics.model_stats) != 1:
- print("FAILED: get_inference_statistics")
- sys.exit(1)
- print(statistics)
- print("Done")
-
- cap.release()
- if FLAGS.out:
- out.release()
- else:
- cv2.destroyAllWindows()
|