Browse Source

Increase FLOPS robustness (#1608)

5.0
Glenn Jocher GitHub 3 years ago
parent
commit
8918e63476
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      utils/torch_utils.py

+ 2
- 2
utils/torch_utils.py View File

@@ -1,12 +1,12 @@
# PyTorch utils

import logging
import math
import os
import time
from contextlib import contextmanager
from copy import deepcopy

import math
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
@@ -152,7 +152,7 @@ def model_info(model, verbose=False, img_size=640):

try: # FLOPS
from thop import profile
stride = int(model.stride.max())
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float

Loading…
Cancel
Save