|
|
@@ -1,12 +1,12 @@ |
|
|
|
# PyTorch utils |
|
|
|
|
|
|
|
import logging |
|
|
|
import math |
|
|
|
import os |
|
|
|
import time |
|
|
|
from contextlib import contextmanager |
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
import math |
|
|
|
import torch |
|
|
|
import torch.backends.cudnn as cudnn |
|
|
|
import torch.nn as nn |
|
|
@@ -152,7 +152,7 @@ def model_info(model, verbose=False, img_size=640): |
|
|
|
|
|
|
|
try: # FLOPS |
|
|
|
from thop import profile |
|
|
|
stride = int(model.stride.max()) |
|
|
|
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32 |
|
|
|
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input |
|
|
|
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS |
|
|
|
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float |