From 05a955a3f6a580307f744178a9ad6938cb326096 Mon Sep 17 00:00:00 2001 From: yujun <50394665+JunnYu@users.noreply.github.com> Date: Thu, 19 Nov 2020 19:56:20 +0800 Subject: [PATCH] FLOPS computation device bug fix (#1447) * Update torch_utils.py fix issue#113 , inputs device should be same with model parameters' device * Update torch_utils.py * Update torch_utils.py Co-authored-by: Glenn Jocher --- utils/torch_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index a6c460f..b330ca5 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -153,7 +153,8 @@ def model_info(model, verbose=False, img_size=640): try: # FLOPS from thop import profile stride = int(model.stride.max()) - flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2 + img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input + flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS except (ImportError, Exception):