|
|
@@ -506,27 +506,27 @@ def check_dataset(data, autodownload=True): |
|
|
|
|
|
|
|
def check_amp(model): |
|
|
|
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation |
|
|
|
from models.common import AutoShape |
|
|
|
from models.common import AutoShape, DetectMultiBackend |
|
|
|
|
|
|
|
def amp_allclose(model, im): |
|
|
|
# All close FP32 vs AMP results |
|
|
|
m = AutoShape(model, verbose=False) # model |
|
|
|
a = m(im).xywhn[0] # FP32 inference |
|
|
|
m.amp = True |
|
|
|
b = m(im).xywhn[0] # AMP inference |
|
|
|
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance |
|
|
|
|
|
|
|
if next(model.parameters()).device.type == 'cpu': # get model device |
|
|
|
return False |
|
|
|
prefix = colorstr('AMP: ') |
|
|
|
file = ROOT / 'data' / 'images' / 'bus.jpg' # image to test |
|
|
|
if file.exists(): |
|
|
|
im = cv2.imread(file)[..., ::-1] # OpenCV image (BGR to RGB) |
|
|
|
elif check_online(): |
|
|
|
im = 'https://ultralytics.com/images/bus.jpg' |
|
|
|
else: |
|
|
|
LOGGER.warning(emojis(f'{prefix}checks skipped ⚠️, not online.')) |
|
|
|
return True |
|
|
|
m = AutoShape(model, verbose=False) # model |
|
|
|
a = m(im).xywhn[0] # FP32 inference |
|
|
|
m.amp = True |
|
|
|
b = m(im).xywhn[0] # AMP inference |
|
|
|
if (a.shape == b.shape) and torch.allclose(a, b, atol=0.05): # close to 5% absolute tolerance |
|
|
|
device = next(model.parameters()).device # get model device |
|
|
|
if device.type == 'cpu': |
|
|
|
return False # AMP disabled on CPU |
|
|
|
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check |
|
|
|
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3)) |
|
|
|
try: |
|
|
|
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im) |
|
|
|
LOGGER.info(emojis(f'{prefix}checks passed ✅')) |
|
|
|
return True |
|
|
|
else: |
|
|
|
except Exception: |
|
|
|
help_url = 'https://github.com/ultralytics/yolov5/issues/7908' |
|
|
|
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')) |
|
|
|
return False |