# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.

Usage:
    $ python path/to/detect.py --weights yolov5s.pt --source 0              # webcam
                                                             img.jpg        # image
                                                             vid.mp4        # video
                                                             path/          # directory
                                                             path/*.jpg     # glob
                                                             'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                             'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync


@torch.no_grad()
def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
        imgsz=(640, 640),  # inference size (height, width); a 2-tuple, unpacked below
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --classes 0, or --classes 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        ):
    source = str(source)
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn)
    stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Half
    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
    if pt:
        model.model.half() if half else model.model.float()

    # Dataloader
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
        bs = len(dataset)  # batch_size
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
        bs = 1  # batch_size
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    if pt and device.type != 'cpu':
        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
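    # dt accumulates per-stage times in seconds (pre-process, inference, NMS); seen counts processed images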
    dt, seen = [0.0, 0.0, 0.0], 0
    for path, im, im0s, vid_cap, s in dataset:
        t1 = time_sync()
        im = torch.from_numpy(im).to(device)
        im = im.half() if half else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        t2 = time_sync()
        dt[0] += t2 - t1

        # Inference
        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
        pred = model(im, augment=augment, visualize=visualize)
        t3 = time_sync()
        dt[1] += t3 - t2

        # NMS
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
        dt[2] += time_sync() - t3

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # im.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
            s += '%gx%g ' % im.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
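                        # each label line reads "class x_center y_center width height [conf]",
                        # coordinates normalized to 0-1, e.g. "0 0.5 0.25 0.4 0.3"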
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or save_crop or view_img:  # Add bbox to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        annotator.box_label(xyxy, label, color=colors(c, True))
                        if save_crop:
                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

            # Print time (inference-only)
            LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

            # Stream results
            im0 = annotator.result()
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer[i].write(im0)

    # Print results
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
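    # expand a single --imgsz value to (h, w), e.g. [640] -> [640, 640]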
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(FILE.stem, opt)
    return opt


def main(opt):
    check_requirements(exclude=('tensorboard', 'thop'))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)