Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

185 lines
9.2KB

  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from models.experimental import attempt_load
  8. from utils.datasets import LoadStreams, LoadImages
  9. from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
  10. scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
  11. from utils.plots import colors, plot_one_box
  12. from utils.torch_utils import select_device, load_classifier, time_synchronized
  13. def detect(opt):
  14. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  15. save_img = not opt.nosave and not source.endswith('.txt') # save inference images
  16. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  17. ('rtsp://', 'rtmp://', 'http://', 'https://'))
  18. # Directories
  19. save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
  20. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  21. # Initialize
  22. set_logging()
  23. device = select_device(opt.device)
  24. half = device.type != 'cpu' # half precision only supported on CUDA
  25. # Load model
  26. model = attempt_load(weights, map_location=device) # load FP32 model
  27. stride = int(model.stride.max()) # model stride
  28. imgsz = check_img_size(imgsz, s=stride) # check img_size
  29. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  30. if half:
  31. model.half() # to FP16
  32. # Second-stage classifier
  33. classify = False
  34. if classify:
  35. modelc = load_classifier(name='resnet101', n=2) # initialize
  36. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  37. # Set Dataloader
  38. vid_path, vid_writer = None, None
  39. if webcam:
  40. view_img = check_imshow()
  41. cudnn.benchmark = True # set True to speed up constant image size inference
  42. dataset = LoadStreams(source, img_size=imgsz, stride=stride)
  43. else:
  44. dataset = LoadImages(source, img_size=imgsz, stride=stride)
  45. # Run inference
  46. if device.type != 'cpu':
  47. model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
  48. t0 = time.time()
  49. for path, img, im0s, vid_cap in dataset:
  50. img = torch.from_numpy(img).to(device)
  51. img = img.half() if half else img.float() # uint8 to fp16/32
  52. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  53. if img.ndimension() == 3:
  54. img = img.unsqueeze(0)
  55. # Inference
  56. t1 = time_synchronized()
  57. pred = model(img, augment=opt.augment)[0]
  58. # Apply NMS
  59. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, opt.classes, opt.agnostic_nms,
  60. max_det=opt.max_det)
  61. t2 = time_synchronized()
  62. # Apply Classifier
  63. if classify:
  64. pred = apply_classifier(pred, modelc, img, im0s)
  65. # Process detections
  66. for i, det in enumerate(pred): # detections per image
  67. if webcam: # batch_size >= 1
  68. p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
  69. else:
  70. p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
  71. p = Path(p) # to Path
  72. save_path = str(save_dir / p.name) # img.jpg
  73. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  74. s += '%gx%g ' % img.shape[2:] # print string
  75. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  76. imc = im0.copy() if opt.save_crop else im0 # for opt.save_crop
  77. if len(det):
  78. # Rescale boxes from img_size to im0 size
  79. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  80. # Print results
  81. for c in det[:, -1].unique():
  82. n = (det[:, -1] == c).sum() # detections per class
  83. s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
  84. # Write results
  85. for *xyxy, conf, cls in reversed(det):
  86. if save_txt: # Write to file
  87. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  88. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  89. with open(txt_path + '.txt', 'a') as f:
  90. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  91. if save_img or opt.save_crop or view_img: # Add bbox to image
  92. c = int(cls) # integer class
  93. label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')
  94. plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness)
  95. if opt.save_crop:
  96. save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
  97. # Print time (inference + NMS)
  98. print(f'{s}Done. ({t2 - t1:.3f}s)')
  99. # Stream results
  100. if view_img:
  101. cv2.imshow(str(p), im0)
  102. cv2.waitKey(1) # 1 millisecond
  103. # Save results (image with detections)
  104. if save_img:
  105. if dataset.mode == 'image':
  106. cv2.imwrite(save_path, im0)
  107. else: # 'video' or 'stream'
  108. if vid_path != save_path: # new video
  109. vid_path = save_path
  110. if isinstance(vid_writer, cv2.VideoWriter):
  111. vid_writer.release() # release previous video writer
  112. if vid_cap: # video
  113. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  114. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  115. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  116. else: # stream
  117. fps, w, h = 30, im0.shape[1], im0.shape[0]
  118. save_path += '.mp4'
  119. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  120. vid_writer.write(im0)
  121. if save_txt or save_img:
  122. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  123. print(f"Results saved to {save_dir}{s}")
  124. print(f'Done. ({time.time() - t0:.3f}s)')
  125. if __name__ == '__main__':
  126. parser = argparse.ArgumentParser()
  127. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  128. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  129. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  130. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  131. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  132. parser.add_argument('--max-det', type=int, default=1000, help='maximum number of detections per image')
  133. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  134. parser.add_argument('--view-img', action='store_true', help='display results')
  135. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  136. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  137. parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
  138. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  139. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  140. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  141. parser.add_argument('--augment', action='store_true', help='augmented inference')
  142. parser.add_argument('--update', action='store_true', help='update all models')
  143. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  144. parser.add_argument('--name', default='exp', help='save results to project/name')
  145. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  146. parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
  147. parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
  148. parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
  149. opt = parser.parse_args()
  150. print(opt)
  151. check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))
  152. with torch.no_grad():
  153. if opt.update: # update all models (to fix SourceChangeWarning)
  154. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  155. detect(opt=opt)
  156. strip_optimizer(opt.weights)
  157. else:
  158. detect(opt=opt)