You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

173 lines
7.8KB

  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from numpy import random
  8. from models.experimental import attempt_load
  9. from utils.datasets import LoadStreams, LoadImages
  10. from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \
  11. strip_optimizer, set_logging, increment_path
  12. from utils.plots import plot_one_box
  13. from utils.torch_utils import select_device, load_classifier, time_synchronized
  14. def detect(save_img=False):
  15. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  16. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  17. ('rtsp://', 'rtmp://', 'http://'))
  18. # Directories
  19. save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
  20. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  21. # Initialize
  22. set_logging()
  23. device = select_device(opt.device)
  24. half = device.type != 'cpu' # half precision only supported on CUDA
  25. # Load model
  26. model = attempt_load(weights, map_location=device) # load FP32 model
  27. imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
  28. if half:
  29. model.half() # to FP16
  30. # Second-stage classifier
  31. classify = False
  32. if classify:
  33. modelc = load_classifier(name='resnet101', n=2) # initialize
  34. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  35. # Set Dataloader
  36. vid_path, vid_writer = None, None
  37. if webcam:
  38. view_img = True
  39. cudnn.benchmark = True # set True to speed up constant image size inference
  40. dataset = LoadStreams(source, img_size=imgsz)
  41. else:
  42. save_img = True
  43. dataset = LoadImages(source, img_size=imgsz)
  44. # Get names and colors
  45. names = model.module.names if hasattr(model, 'module') else model.names
  46. colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
  47. # Run inference
  48. t0 = time.time()
  49. img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
  50. _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
  51. for path, img, im0s, vid_cap in dataset:
  52. img = torch.from_numpy(img).to(device)
  53. img = img.half() if half else img.float() # uint8 to fp16/32
  54. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  55. if img.ndimension() == 3:
  56. img = img.unsqueeze(0)
  57. # Inference
  58. t1 = time_synchronized()
  59. pred = model(img, augment=opt.augment)[0]
  60. # Apply NMS
  61. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  62. t2 = time_synchronized()
  63. # Apply Classifier
  64. if classify:
  65. pred = apply_classifier(pred, modelc, img, im0s)
  66. # Process detections
  67. for i, det in enumerate(pred): # detections per image
  68. if webcam: # batch_size >= 1
  69. p, s, im0 = Path(path[i]), '%g: ' % i, im0s[i].copy()
  70. else:
  71. p, s, im0 = Path(path), '', im0s
  72. save_path = str(save_dir / p.name)
  73. txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
  74. s += '%gx%g ' % img.shape[2:] # print string
  75. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  76. if det is not None and len(det):
  77. # Rescale boxes from img_size to im0 size
  78. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  79. # Print results
  80. for c in det[:, -1].unique():
  81. n = (det[:, -1] == c).sum() # detections per class
  82. s += '%g %ss, ' % (n, names[int(c)]) # add to string
  83. # Write results
  84. for *xyxy, conf, cls in reversed(det):
  85. if save_txt: # Write to file
  86. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  87. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  88. with open(txt_path + '.txt', 'a') as f:
  89. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  90. if save_img or view_img: # Add bbox to image
  91. label = '%s %.2f' % (names[int(cls)], conf)
  92. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  93. # Print time (inference + NMS)
  94. print('%sDone. (%.3fs)' % (s, t2 - t1))
  95. # Stream results
  96. if view_img:
  97. cv2.imshow(p, im0)
  98. if cv2.waitKey(1) == ord('q'): # q to quit
  99. raise StopIteration
  100. # Save results (image with detections)
  101. if save_img:
  102. if dataset.mode == 'images':
  103. cv2.imwrite(save_path, im0)
  104. else:
  105. if vid_path != save_path: # new video
  106. vid_path = save_path
  107. if isinstance(vid_writer, cv2.VideoWriter):
  108. vid_writer.release() # release previous video writer
  109. fourcc = 'mp4v' # output video codec
  110. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  111. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  112. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  113. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
  114. vid_writer.write(im0)
  115. if save_txt or save_img:
  116. print('Results saved to %s' % save_dir)
  117. print('Done. (%.3fs)' % (time.time() - t0))
  118. if __name__ == '__main__':
  119. parser = argparse.ArgumentParser()
  120. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  121. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  122. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  123. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  124. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  125. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  126. parser.add_argument('--view-img', action='store_true', help='display results')
  127. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  128. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  129. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  130. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  131. parser.add_argument('--augment', action='store_true', help='augmented inference')
  132. parser.add_argument('--update', action='store_true', help='update all models')
  133. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  134. parser.add_argument('--name', default='exp', help='save results to project/name')
  135. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  136. opt = parser.parse_args()
  137. print(opt)
  138. with torch.no_grad():
  139. if opt.update: # update all models (to fix SourceChangeWarning)
  140. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  141. detect()
  142. strip_optimizer(opt.weights)
  143. else:
  144. detect()