Browse Source

update check_img_size() for model strides

5.0
Glenn Jocher 4 years ago
parent
commit
a9dc0c2c29
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      detect.py
  2. +1
    -1
      test.py

+ 1
- 1
detect.py View File

# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning # torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
# model.fuse() # model.fuse()
model.to(device).eval() model.to(device).eval()
imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size
if half: if half:
model.half() # to FP16 model.half() # to FP16


parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--augment', action='store_true', help='augmented inference')
opt = parser.parse_args() opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size)
print(opt) print(opt)


with torch.no_grad(): with torch.no_grad():

+ 1
- 1
test.py View File

torch_utils.model_info(model) torch_utils.model_info(model)
model.fuse() model.fuse()
model.to(device) model.to(device)
imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size


# Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99 # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
# if device.type != 'cpu' and torch.cuda.device_count() > 1: # if device.type != 'cpu' and torch.cuda.device_count() > 1:
parser.add_argument('--merge', action='store_true', help='use Merge NMS') parser.add_argument('--merge', action='store_true', help='use Merge NMS')
parser.add_argument('--verbose', action='store_true', help='report mAP by class') parser.add_argument('--verbose', action='store_true', help='report mAP by class')
opt = parser.parse_args() opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size)
opt.save_json = opt.save_json or opt.data.endswith('coco.yaml') opt.save_json = opt.save_json or opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file opt.data = check_file(opt.data) # check file
print(opt) print(opt)

Loading…
Cancel
Save