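"""YOLOv5 inference script.

Runs detection on images, videos, directories, webcams, or rtsp/rtmp/http
streams and writes annotated media (plus optional *.txt label files) to
--save-dir. A typical invocation, assuming this file is the repository's
detect.py:

    python detect.py --source inference/images --weights yolov5s.pt
"""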
import argparse
import os
import shutil
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords,
    xyxy2xywh, plot_one_box, strip_optimizer, set_logging)
from utils.torch_utils import select_device, load_classifier, time_synchronized
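

# NOTE: detect() reads its settings from the module-level `opt` namespace
# built by the argparser in the __main__ block at the bottom of this file.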
def detect(save_img=False):
    out, source, weights, view_img, save_txt, imgsz = \
        opt.save_dir, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source.isnumeric() or source.startswith(('rtsp://', 'rtmp://', 'http://')) or source.endswith('.txt')

    # Initialize
    set_logging()
    device = select_device(opt.device)
    if os.path.exists(out):  # output dir
        shutil.rmtree(out)  # delete dir
    os.makedirs(out)  # make new dir
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    if half:
        model.half()  # to FP16

    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
        modelc.to(device).eval()
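
    # NOTE: the second stage is disabled by default; setting `classify = True`
    # expects a ResNet101 checkpoint at weights/resnet101.pt and runs a second
    # classifier over each detection crop.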

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        save_img = True
        dataset = LoadImages(source, img_size=imgsz)
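
    # LoadStreams polls live sources (a webcam index, stream URL, or a .txt
    # list of sources) on background threads; LoadImages steps through image
    # files and video frames one at a time.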

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
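
    # The dummy forward pass above runs the model once so one-off CUDA setup
    # (memory allocation, kernel selection) is not billed to the first frame.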
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
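        # the loader yields letterboxed CHW arrays; the lines above move them
        # to the device, scale pixels to [0, 1], and add a batch dimension
        # when a single image arrives without one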

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)
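
        # after NMS, `pred` holds one tensor per image; each row is
        # [x1, y1, x2, y2, confidence, class] at the letterboxed image scale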

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s

            save_path = str(Path(out) / Path(p).name)
            txt_path = str(Path(out) / Path(p).stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
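            # gn = [w, h, w, h] of the original image, used to normalize
            # xywh boxes to 0-1 when writing YOLO-format labels below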
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, conf, *xywh) if opt.save_conf else (cls, *xywh)  # label format
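                        # saved line layout: class [conf] x_center y_center width height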
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line) + '\n') % line)

                    if save_img or view_img:  # Add bbox to image
                        label = '%s %.2f' % (names[int(cls)], conf)
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            if view_img:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fourcc = 'mp4v'  # output video codec
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
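                    # frames keep appending to this writer until save_path
                    # changes, i.e. one output file per input video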
                    vid_writer.write(im0)

    if save_txt or save_img:
        print('Results saved to %s' % Path(out))

    print('Done. (%.3fs)' % (time.time() - t0))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-dir', type=str, default='inference/output', help='directory to save results')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    opt = parser.parse_args()
    print(opt)
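
    # inference only: run under no_grad to skip autograd bookkeeping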
    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                detect()
                strip_optimizer(opt.weights)
        else:
            detect()