You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

175 lines
7.8KB

  1. import argparse
  2. import os
  3. import time
  4. from pathlib import Path
  5. import cv2
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. from numpy import random
  9. from models.experimental import attempt_load
  10. from utils.datasets import LoadStreams, LoadImages
  11. from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \
  12. plot_one_box, strip_optimizer, set_logging, increment_dir
  13. from utils.torch_utils import select_device, load_classifier, time_synchronized
  14. def detect(save_img=False):
  15. save_dir, source, weights, view_img, save_txt, imgsz = \
  16. Path(opt.save_dir), opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  17. webcam = source.isnumeric() or source.startswith(('rtsp://', 'rtmp://', 'http://')) or source.endswith('.txt')
  18. # Directories
  19. if save_dir == Path('runs/detect'): # if default
  20. os.makedirs('runs/detect', exist_ok=True) # make base
  21. save_dir = Path(increment_dir(save_dir / 'exp', opt.name)) # increment run
  22. os.makedirs(save_dir / 'labels' if save_txt else save_dir, exist_ok=True) # make new dir
  23. # Initialize
  24. set_logging()
  25. device = select_device(opt.device)
  26. half = device.type != 'cpu' # half precision only supported on CUDA
  27. # Load model
  28. model = attempt_load(weights, map_location=device) # load FP32 model
  29. imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
  30. if half:
  31. model.half() # to FP16
  32. # Second-stage classifier
  33. classify = False
  34. if classify:
  35. modelc = load_classifier(name='resnet101', n=2) # initialize
  36. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
  37. modelc.to(device).eval()
  38. # Set Dataloader
  39. vid_path, vid_writer = None, None
  40. if webcam:
  41. view_img = True
  42. cudnn.benchmark = True # set True to speed up constant image size inference
  43. dataset = LoadStreams(source, img_size=imgsz)
  44. else:
  45. save_img = True
  46. dataset = LoadImages(source, img_size=imgsz)
  47. # Get names and colors
  48. names = model.module.names if hasattr(model, 'module') else model.names
  49. colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
  50. # Run inference
  51. t0 = time.time()
  52. img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
  53. _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
  54. for path, img, im0s, vid_cap in dataset:
  55. img = torch.from_numpy(img).to(device)
  56. img = img.half() if half else img.float() # uint8 to fp16/32
  57. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  58. if img.ndimension() == 3:
  59. img = img.unsqueeze(0)
  60. # Inference
  61. t1 = time_synchronized()
  62. pred = model(img, augment=opt.augment)[0]
  63. # Apply NMS
  64. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  65. t2 = time_synchronized()
  66. # Apply Classifier
  67. if classify:
  68. pred = apply_classifier(pred, modelc, img, im0s)
  69. # Process detections
  70. for i, det in enumerate(pred): # detections per image
  71. if webcam: # batch_size >= 1
  72. p, s, im0 = Path(path[i]), '%g: ' % i, im0s[i].copy()
  73. else:
  74. p, s, im0 = Path(path), '', im0s
  75. save_path = str(save_dir / p.name)
  76. txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
  77. s += '%gx%g ' % img.shape[2:] # print string
  78. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  79. if det is not None and len(det):
  80. # Rescale boxes from img_size to im0 size
  81. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  82. # Print results
  83. for c in det[:, -1].unique():
  84. n = (det[:, -1] == c).sum() # detections per class
  85. s += '%g %ss, ' % (n, names[int(c)]) # add to string
  86. # Write results
  87. for *xyxy, conf, cls in reversed(det):
  88. if save_txt: # Write to file
  89. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  90. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  91. with open(txt_path + '.txt', 'a') as f:
  92. f.write(('%g ' * len(line) + '\n') % line)
  93. if save_img or view_img: # Add bbox to image
  94. label = '%s %.2f' % (names[int(cls)], conf)
  95. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  96. # Print time (inference + NMS)
  97. print('%sDone. (%.3fs)' % (s, t2 - t1))
  98. # Stream results
  99. if view_img:
  100. cv2.imshow(p, im0)
  101. if cv2.waitKey(1) == ord('q'): # q to quit
  102. raise StopIteration
  103. # Save results (image with detections)
  104. if save_img:
  105. if dataset.mode == 'images':
  106. cv2.imwrite(save_path, im0)
  107. else:
  108. if vid_path != save_path: # new video
  109. vid_path = save_path
  110. if isinstance(vid_writer, cv2.VideoWriter):
  111. vid_writer.release() # release previous video writer
  112. fourcc = 'mp4v' # output video codec
  113. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  114. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  115. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  116. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
  117. vid_writer.write(im0)
  118. if save_txt or save_img:
  119. print('Results saved to %s' % save_dir)
  120. print('Done. (%.3fs)' % (time.time() - t0))
  121. if __name__ == '__main__':
  122. parser = argparse.ArgumentParser()
  123. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  124. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  125. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  126. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  127. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  128. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  129. parser.add_argument('--view-img', action='store_true', help='display results')
  130. parser.add_argument('--save-txt', action='store_false', help='save results to *.txt')
  131. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  132. parser.add_argument('--save-dir', type=str, default='runs/detect', help='directory to save results')
  133. parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}')
  134. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  135. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  136. parser.add_argument('--augment', action='store_true', help='augmented inference')
  137. parser.add_argument('--update', action='store_true', help='update all models')
  138. opt = parser.parse_args()
  139. print(opt)
  140. with torch.no_grad():
  141. if opt.update: # update all models (to fix SourceChangeWarning)
  142. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  143. detect()
  144. strip_optimizer(opt.weights)
  145. else:
  146. detect()