Update `profile()` for CUDA Memory allocation (#4239)
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Cleanup
This commit is contained in:
parent
bceb57b910
commit
d8f18834a2
|
|
@ -1172,11 +1172,11 @@
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# Profile\n",
|
"# Profile\n",
|
||||||
"from utils.torch_utils import profile \n",
|
"from utils.torch_utils import profile\n",
|
||||||
"\n",
|
"\n",
|
||||||
"m1 = lambda x: x * torch.sigmoid(x)\n",
|
"m1 = lambda x: x * torch.sigmoid(x)\n",
|
||||||
"m2 = torch.nn.SiLU()\n",
|
"m2 = torch.nn.SiLU()\n",
|
||||||
"profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
|
"results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
|
||||||
],
|
],
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
|
|
|
||||||
|
|
@ -98,42 +98,56 @@ def time_sync():
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|
||||||
def profile(x, ops, n=100, device=None):
|
def profile(input, ops, n=10, device=None):
|
||||||
# profile a pytorch module or list of modules. Example usage:
|
# YOLOv5 speed/memory/FLOPs profiler
|
||||||
# x = torch.randn(16, 3, 640, 640) # input
|
#
|
||||||
|
# Usage:
|
||||||
|
# input = torch.randn(16, 3, 640, 640)
|
||||||
# m1 = lambda x: x * torch.sigmoid(x)
|
# m1 = lambda x: x * torch.sigmoid(x)
|
||||||
# m2 = nn.SiLU()
|
# m2 = nn.SiLU()
|
||||||
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
|
# profile(input, [m1, m2], n=100) # profile over 100 iterations
|
||||||
|
|
||||||
|
results = []
|
||||||
device = device or select_device()
|
device = device or select_device()
|
||||||
|
print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||||
|
f"{'input':>24s}{'output':>24s}")
|
||||||
|
|
||||||
|
for x in input if isinstance(input, list) else [input]:
|
||||||
x = x.to(device)
|
x = x.to(device)
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
print(f"{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
|
|
||||||
for m in ops if isinstance(ops, list) else [ops]:
|
for m in ops if isinstance(ops, list) else [ops]:
|
||||||
m = m.to(device) if hasattr(m, 'to') else m # device
|
m = m.to(device) if hasattr(m, 'to') else m # device
|
||||||
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
|
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
||||||
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
tf, tb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
||||||
try:
|
try:
|
||||||
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
|
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
|
||||||
except:
|
except:
|
||||||
flops = 0
|
flops = 0
|
||||||
|
|
||||||
|
try:
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
t[0] = time_sync()
|
t[0] = time_sync()
|
||||||
y = m(x)
|
y = m(x)
|
||||||
t[1] = time_sync()
|
t[1] = time_sync()
|
||||||
try:
|
try:
|
||||||
_ = y.sum().backward()
|
_ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
|
||||||
t[2] = time_sync()
|
t[2] = time_sync()
|
||||||
except: # no backward method
|
except Exception as e: # no backward method
|
||||||
|
print(e)
|
||||||
t[2] = float('nan')
|
t[2] = float('nan')
|
||||||
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||||
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||||
|
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||||
s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
|
s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
|
||||||
s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
|
s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
|
||||||
p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
|
p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
|
||||||
print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||||
|
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
results.append(None)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def is_parallel(model):
|
def is_parallel(model):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue