You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import argparse
  2. import os
  3. import platform
  4. import shutil
  5. import time
  6. from pathlib import Path
  7. import cv2
  8. import torch
  9. import torch.backends.cudnn as cudnn
  10. from numpy import random
  11. from models.experimental import attempt_load
  12. from utils.datasets import LoadStreams, LoadImages
  13. from utils.general import (
  14. check_img_size, non_max_suppression, apply_classifier, scale_coords,
  15. xyxy2xywh, plot_one_box, strip_optimizer, set_logging)
  16. from utils.torch_utils import select_device, load_classifier, time_synchronized
  17. def detect(save_img=False):
  18. out, source, weights, view_img, save_txt, imgsz = \
  19. opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  20. webcam = source.isnumeric() or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')
  21. # Initialize
  22. set_logging()
  23. device = select_device(opt.device)
  24. if os.path.exists(out):
  25. shutil.rmtree(out) # delete output folder
  26. os.makedirs(out) # make new output folder
  27. half = device.type != 'cpu' # half precision only supported on CUDA
  28. # Load model
  29. model = attempt_load(weights, map_location=device) # load FP32 model
  30. imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
  31. if half:
  32. model.half() # to FP16
  33. # Second-stage classifier
  34. classify = False
  35. if classify:
  36. modelc = load_classifier(name='resnet101', n=2) # initialize
  37. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
  38. modelc.to(device).eval()
  39. # Set Dataloader
  40. vid_path, vid_writer = None, None
  41. if webcam:
  42. view_img = True
  43. cudnn.benchmark = True # set True to speed up constant image size inference
  44. dataset = LoadStreams(source, img_size=imgsz)
  45. else:
  46. save_img = True
  47. dataset = LoadImages(source, img_size=imgsz)
  48. # Get names and colors
  49. names = model.module.names if hasattr(model, 'module') else model.names
  50. colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
  51. # Run inference
  52. t0 = time.time()
  53. img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
  54. _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
  55. for path, img, im0s, vid_cap in dataset:
  56. img = torch.from_numpy(img).to(device)
  57. img = img.half() if half else img.float() # uint8 to fp16/32
  58. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  59. if img.ndimension() == 3:
  60. img = img.unsqueeze(0)
  61. # Inference
  62. t1 = time_synchronized()
  63. pred = model(img, augment=opt.augment)[0]
  64. # Apply NMS
  65. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  66. t2 = time_synchronized()
  67. # Apply Classifier
  68. if classify:
  69. pred = apply_classifier(pred, modelc, img, im0s)
  70. # Process detections
  71. for i, det in enumerate(pred): # detections per image
  72. if webcam: # batch_size >= 1
  73. p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
  74. else:
  75. p, s, im0 = path, '', im0s
  76. save_path = str(Path(out) / Path(p).name)
  77. txt_path = str(Path(out) / Path(p).stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
  78. s += '%gx%g ' % img.shape[2:] # print string
  79. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  80. if det is not None and len(det):
  81. # Rescale boxes from img_size to im0 size
  82. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  83. # Print results
  84. for c in det[:, -1].unique():
  85. n = (det[:, -1] == c).sum() # detections per class
  86. s += '%g %ss, ' % (n, names[int(c)]) # add to string
  87. # Write results
  88. for *xyxy, conf, cls in reversed(det):
  89. if save_txt: # Write to file
  90. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  91. with open(txt_path + '.txt', 'a') as f:
  92. f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
  93. if save_img or view_img: # Add bbox to image
  94. label = '%s %.2f' % (names[int(cls)], conf)
  95. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  96. # Print time (inference + NMS)
  97. print('%sDone. (%.3fs)' % (s, t2 - t1))
  98. # Stream results
  99. if view_img:
  100. cv2.imshow(p, im0)
  101. if cv2.waitKey(1) == ord('q'): # q to quit
  102. raise StopIteration
  103. # Save results (image with detections)
  104. if save_img:
  105. if dataset.mode == 'images':
  106. cv2.imwrite(save_path, im0)
  107. else:
  108. if vid_path != save_path: # new video
  109. vid_path = save_path
  110. if isinstance(vid_writer, cv2.VideoWriter):
  111. vid_writer.release() # release previous video writer
  112. fourcc = 'mp4v' # output video codec
  113. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  114. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  115. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  116. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
  117. vid_writer.write(im0)
  118. if save_txt or save_img:
  119. print('Results saved to %s' % Path(out))
  120. if platform.system() == 'Darwin' and not opt.update: # MacOS
  121. os.system('open ' + save_path)
  122. print('Done. (%.3fs)' % (time.time() - t0))
  123. if __name__ == '__main__':
  124. parser = argparse.ArgumentParser()
  125. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  126. parser.add_argument('--source', type=str, default='inference/images', help='source') # file/folder, 0 for webcam
  127. parser.add_argument('--output', type=str, default='inference/output', help='output folder') # output folder
  128. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  129. parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
  130. parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
  131. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  132. parser.add_argument('--view-img', action='store_true', help='display results')
  133. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  134. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  135. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  136. parser.add_argument('--augment', action='store_true', help='augmented inference')
  137. parser.add_argument('--update', action='store_true', help='update all models')
  138. opt = parser.parse_args()
  139. print(opt)
  140. with torch.no_grad():
  141. if opt.update: # update all models (to fix SourceChangeWarning)
  142. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  143. detect()
  144. strip_optimizer(opt.weights)
  145. else:
  146. detect()