Selaa lähdekoodia

AMP check improvements backup YOLOv5n pretrained (#7959)

* Reduce AMP check to detections verification

More robust and faster

* Update general.py

* Update general.py
modifyDataloader
Glenn Jocher GitHub 2 vuotta sitten
vanhempi
commit
d07f9ce0ea
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 muutettua tiedostoa jossa 17 lisäystä ja 17 poistoa
  1. +17
    -17
      utils/general.py

+ 17
- 17
utils/general.py Näytä tiedosto

@@ -506,27 +506,27 @@ def check_dataset(data, autodownload=True):

def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
from models.common import AutoShape
from models.common import AutoShape, DetectMultiBackend

def amp_allclose(model, im):
# All close FP32 vs AMP results
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
b = m(im).xywhn[0] # AMP inference
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance

if next(model.parameters()).device.type == 'cpu': # get model device
return False
prefix = colorstr('AMP: ')
file = ROOT / 'data' / 'images' / 'bus.jpg' # image to test
if file.exists():
im = cv2.imread(file)[..., ::-1] # OpenCV image (BGR to RGB)
elif check_online():
im = 'https://ultralytics.com/images/bus.jpg'
else:
LOGGER.warning(emojis(f'{prefix}checks skipped ⚠️, not online.'))
return True
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
b = m(im).xywhn[0] # AMP inference
if (a.shape == b.shape) and torch.allclose(a, b, atol=0.05): # close to 5% absolute tolerance
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
return False # AMP disabled on CPU
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
try:
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
return True
else:
except Exception:
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
return False

Loading…
Peruuta
Tallenna