You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
9.3KB

  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from models.experimental import attempt_load
  8. from utils.datasets import LoadStreams, LoadImages
  9. from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
  10. scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
  11. from utils.plots import colors, plot_one_box
  12. from utils.torch_utils import select_device, load_classifier, time_synchronized
  13. @torch.no_grad()
  14. def detect(opt):
  15. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  16. save_img = not opt.nosave and not source.endswith('.txt') # save inference images
  17. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  18. ('rtsp://', 'rtmp://', 'http://', 'https://'))
  19. # Directories
  20. save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
  21. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  22. # Initialize
  23. set_logging()
  24. device = select_device(opt.device)
  25. half = opt.half and device.type != 'cpu' # half precision only supported on CUDA
  26. # Load model
  27. model = attempt_load(weights, map_location=device) # load FP32 model
  28. stride = int(model.stride.max()) # model stride
  29. imgsz = check_img_size(imgsz, s=stride) # check img_size
  30. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  31. if half:
  32. model.half() # to FP16
  33. # Second-stage classifier
  34. classify = False
  35. if classify:
  36. modelc = load_classifier(name='resnet101', n=2) # initialize
  37. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  38. # Set Dataloader
  39. vid_path, vid_writer = None, None
  40. if webcam:
  41. view_img = check_imshow()
  42. cudnn.benchmark = True # set True to speed up constant image size inference
  43. dataset = LoadStreams(source, img_size=imgsz, stride=stride)
  44. else:
  45. dataset = LoadImages(source, img_size=imgsz, stride=stride)
  46. # Run inference
  47. if device.type != 'cpu':
  48. model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
  49. t0 = time.time()
  50. for path, img, im0s, vid_cap in dataset:
  51. img = torch.from_numpy(img).to(device)
  52. img = img.half() if half else img.float() # uint8 to fp16/32
  53. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  54. if img.ndimension() == 3:
  55. img = img.unsqueeze(0)
  56. # Inference
  57. t1 = time_synchronized()
  58. pred = model(img, augment=opt.augment)[0]
  59. # Apply NMS
  60. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, opt.classes, opt.agnostic_nms,
  61. max_det=opt.max_det)
  62. t2 = time_synchronized()
  63. # Apply Classifier
  64. if classify:
  65. pred = apply_classifier(pred, modelc, img, im0s)
  66. # Process detections
  67. for i, det in enumerate(pred): # detections per image
  68. if webcam: # batch_size >= 1
  69. p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
  70. else:
  71. p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
  72. p = Path(p) # to Path
  73. save_path = str(save_dir / p.name) # img.jpg
  74. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  75. s += '%gx%g ' % img.shape[2:] # print string
  76. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  77. imc = im0.copy() if opt.save_crop else im0 # for opt.save_crop
  78. if len(det):
  79. # Rescale boxes from img_size to im0 size
  80. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  81. # Print results
  82. for c in det[:, -1].unique():
  83. n = (det[:, -1] == c).sum() # detections per class
  84. s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
  85. # Write results
  86. for *xyxy, conf, cls in reversed(det):
  87. if save_txt: # Write to file
  88. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  89. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  90. with open(txt_path + '.txt', 'a') as f:
  91. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  92. if save_img or opt.save_crop or view_img: # Add bbox to image
  93. c = int(cls) # integer class
  94. label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')
  95. plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness)
  96. if opt.save_crop:
  97. save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
  98. # Print time (inference + NMS)
  99. print(f'{s}Done. ({t2 - t1:.3f}s)')
  100. # Stream results
  101. if view_img:
  102. cv2.imshow(str(p), im0)
  103. cv2.waitKey(1) # 1 millisecond
  104. # Save results (image with detections)
  105. if save_img:
  106. if dataset.mode == 'image':
  107. cv2.imwrite(save_path, im0)
  108. else: # 'video' or 'stream'
  109. if vid_path != save_path: # new video
  110. vid_path = save_path
  111. if isinstance(vid_writer, cv2.VideoWriter):
  112. vid_writer.release() # release previous video writer
  113. if vid_cap: # video
  114. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  115. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  116. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  117. else: # stream
  118. fps, w, h = 30, im0.shape[1], im0.shape[0]
  119. save_path += '.mp4'
  120. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  121. vid_writer.write(im0)
  122. if save_txt or save_img:
  123. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  124. print(f"Results saved to {save_dir}{s}")
  125. print(f'Done. ({time.time() - t0:.3f}s)')
  126. if __name__ == '__main__':
  127. parser = argparse.ArgumentParser()
  128. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  129. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  130. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  131. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  132. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  133. parser.add_argument('--max-det', type=int, default=1000, help='maximum number of detections per image')
  134. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  135. parser.add_argument('--view-img', action='store_true', help='display results')
  136. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  137. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  138. parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
  139. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  140. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  141. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  142. parser.add_argument('--augment', action='store_true', help='augmented inference')
  143. parser.add_argument('--update', action='store_true', help='update all models')
  144. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  145. parser.add_argument('--name', default='exp', help='save results to project/name')
  146. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  147. parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
  148. parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
  149. parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
  150. parser.add_argument('--half', type=bool, default=False, help='use FP16 half-precision inference')
  151. opt = parser.parse_args()
  152. print(opt)
  153. check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))
  154. if opt.update: # update all models (to fix SourceChangeWarning)
  155. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  156. detect(opt=opt)
  157. strip_optimizer(opt.weights)
  158. else:
  159. detect(opt=opt)