# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.

Usage:
    $ python path/to/detect.py --weights yolov5s.pt --source 0  # webcam
                                                             img.jpg  # image
                                                             vid.mp4  # video
                                                             path/  # directory
                                                             path/*.jpg  # glob
                                                             'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                             'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
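

# Gradients are never needed at inference time; the decorator below disables
# autograd for the whole run() call, reducing memory use and overhead.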
@torch.no_grad()
def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IoU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --classes 0, or --classes 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        ):
    source = str(source)
    save_img = not nosave and not source.endswith('.txt')  # save inference images
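    # Classify the source: a numeric string ('0') selects a webcam, a '.txt' file
    # lists multiple streams, and a URL is streamed unless it points at a media file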
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn)
    stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Half
    half &= (pt or jit or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
    if pt or jit:
        model.model.half() if half else model.model.float()

    # Dataloader
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
        bs = len(dataset)  # batch_size
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
        bs = 1  # batch_size
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    model.warmup(imgsz=(1, 3, *imgsz), half=half)  # warmup
    dt, seen = [0.0, 0.0, 0.0], 0
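    # Per image/frame: pre-process -> forward pass -> NMS; dt[0..2] accumulates
    # the wall-clock time of each stage and seen counts processed images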
    for path, im, im0s, vid_cap, s in dataset:
        t1 = time_sync()
        im = torch.from_numpy(im).to(device)
        im = im.half() if half else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        t2 = time_sync()
        dt[0] += t2 - t1

        # Inference
        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
        pred = model(im, augment=augment, visualize=visualize)
        t3 = time_sync()
        dt[1] += t3 - t2

        # NMS
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
        dt[2] += time_sync() - t3

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # im.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
            s += '%gx%g ' % im.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or save_crop or view_img:  # Add bbox to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        annotator.box_label(xyxy, label, color=colors(c, True))
                        if save_crop:
                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

            # Print time (inference-only)
            LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

            # Stream results
            im0 = annotator.result()
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path += '.mp4'
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer[i].write(im0)

    # Print results
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand [640] to [640, 640]
    print_args(FILE.stem, opt)
    return opt


def main(opt):
    check_requirements(exclude=('tensorboard', 'thop'))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
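

# Programmatic usage (a minimal sketch, not part of the original script): run()
# accepts the same options as the CLI flags above; the weights and source paths
# here are hypothetical examples.
#
#   from detect import run
#   run(weights='yolov5s.pt', source='data/images', imgsz=(640, 640), conf_thres=0.4)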