Implement `@torch.no_grad()` decorator (#3312)

* `@torch.no_grad()` decorator

* Update detect.py
This commit is contained in:
Glenn Jocher 2021-05-24 13:23:09 +02:00 committed by GitHub
parent 73a92dc1b6
commit 61ea23c3fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 20 deletions

View File

@@ -14,6 +14,7 @@ from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized
@torch.no_grad()
def detect(opt):
    source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
@@ -175,7 +176,6 @@ if __name__ == '__main__':
    print(opt)
    check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))
    with torch.no_grad():
    if opt.update:  # update all models (to fix SourceChangeWarning)
        for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
            detect(opt=opt)

View File

@@ -18,6 +18,7 @@ from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_synchronized
@torch.no_grad()
def test(data,
         weights=None,
         batch_size=32,
@@ -105,7 +106,6 @@ def test(data,
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width
        with torch.no_grad():
            # Run model
            t = time_synchronized()
            out, train_out = model(img, augment=augment)  # inference and training outputs