torch.cuda.amp bug fix (#2750)
PR https://github.com/ultralytics/yolov5/pull/2725 introduced a very specific bug that only affects multi-GPU trainings. Apparently the cause was using the torch.cuda.amp decorator in the autoShape forward method. I've implemented amp more traditionally in this PR, and the bug is resolved.
This commit is contained in:
parent
fca5e2a48f
commit
b5de52c4cd
|
|
@ -10,6 +10,7 @@ import requests
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from torch.cuda import amp
|
||||
|
||||
from utils.datasets import letterbox
|
||||
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
|
||||
|
|
@ -237,7 +238,6 @@ class autoShape(nn.Module):
|
|||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.cuda.amp.autocast(torch.cuda.is_available())
|
||||
def forward(self, imgs, size=640, augment=False, profile=False):
|
||||
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
||||
# filename: imgs = 'data/samples/zidane.jpg'
|
||||
|
|
@ -251,7 +251,8 @@ class autoShape(nn.Module):
|
|||
t = [time_synchronized()]
|
||||
p = next(self.model.parameters()) # for device and type
|
||||
if isinstance(imgs, torch.Tensor): # torch
|
||||
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
||||
with amp.autocast(enabled=p.device.type != 'cpu'):
|
||||
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
||||
|
||||
# Pre-process
|
||||
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
|
||||
|
|
@ -278,17 +279,18 @@ class autoShape(nn.Module):
|
|||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
|
||||
t.append(time_synchronized())
|
||||
|
||||
# Inference
|
||||
y = self.model(x, augment, profile)[0] # forward
|
||||
t.append(time_synchronized())
|
||||
with amp.autocast(enabled=p.device.type != 'cpu'):
|
||||
# Inference
|
||||
y = self.model(x, augment, profile)[0] # forward
|
||||
t.append(time_synchronized())
|
||||
|
||||
# Post-process
|
||||
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
|
||||
for i in range(n):
|
||||
scale_coords(shape1, y[i][:, :4], shape0[i])
|
||||
# Post-process
|
||||
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
|
||||
for i in range(n):
|
||||
scale_coords(shape1, y[i][:, :4], shape0[i])
|
||||
|
||||
t.append(time_synchronized())
|
||||
return Detections(imgs, y, files, t, self.names, x.shape)
|
||||
t.append(time_synchronized())
|
||||
return Detections(imgs, y, files, t, self.names, x.shape)
|
||||
|
||||
|
||||
class Detections:
|
||||
|
|
|
|||
Loading…
Reference in New Issue