Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

plots.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. # Plotting utils
  2. import glob
  3. import os
  4. import random
  5. from copy import copy
  6. from pathlib import Path
  7. import cv2
  8. import math
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import torch
  13. import yaml
  14. from PIL import Image, ImageDraw
  15. from scipy.signal import butter, filtfilt
  16. from utils.general import xywh2xyxy, xyxy2xywh
  17. from utils.metrics import fitness
  18. # Settings
  19. matplotlib.rc('font', **{'size': 11})
  20. matplotlib.use('Agg') # for writing to files only
  21. def color_list():
  22. # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  23. def hex2rgb(h):
  24. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  25. return [hex2rgb(h) for h in plt.rcParams['axes.prop_cycle'].by_key()['color']]
  26. def hist2d(x, y, n=100):
  27. # 2d histogram used in labels.png and evolve.png
  28. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  29. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  30. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  31. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  32. return np.log(hist[xidx, yidx])
  33. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  34. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  35. def butter_lowpass(cutoff, fs, order):
  36. nyq = 0.5 * fs
  37. normal_cutoff = cutoff / nyq
  38. return butter(order, normal_cutoff, btype='low', analog=False)
  39. b, a = butter_lowpass(cutoff, fs, order=order)
  40. return filtfilt(b, a, data) # forward-backward filter
  41. def plot_one_box(x, img, color=None, label=None, line_thickness=None):
  42. # Plots one bounding box on image img
  43. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  44. color = color or [random.randint(0, 255) for _ in range(3)]
  45. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  46. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  47. if label:
  48. tf = max(tl - 1, 1) # font thickness
  49. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  50. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  51. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  52. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  53. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  54. # Compares the two methods for width-height anchor multiplication
  55. # https://github.com/ultralytics/yolov3/issues/168
  56. x = np.arange(-4.0, 4.0, .1)
  57. ya = np.exp(x)
  58. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  59. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  60. plt.plot(x, ya, '.-', label='YOLOv3')
  61. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  62. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  63. plt.xlim(left=-4, right=4)
  64. plt.ylim(bottom=0, top=6)
  65. plt.xlabel('input')
  66. plt.ylabel('output')
  67. plt.grid()
  68. plt.legend()
  69. fig.savefig('comparison.png', dpi=200)
  70. def output_to_target(output):
  71. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  72. targets = []
  73. for i, o in enumerate(output):
  74. for *box, conf, cls in o.cpu().numpy():
  75. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  76. return np.array(targets)
  77. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  78. # Plot image grid with labels
  79. if isinstance(images, torch.Tensor):
  80. images = images.cpu().float().numpy()
  81. if isinstance(targets, torch.Tensor):
  82. targets = targets.cpu().numpy()
  83. # un-normalise
  84. if np.max(images[0]) <= 1:
  85. images *= 255
  86. tl = 3 # line thickness
  87. tf = max(tl - 1, 1) # font thickness
  88. bs, _, h, w = images.shape # batch size, _, height, width
  89. bs = min(bs, max_subplots) # limit plot images
  90. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  91. # Check if we should resize
  92. scale_factor = max_size / max(h, w)
  93. if scale_factor < 1:
  94. h = math.ceil(scale_factor * h)
  95. w = math.ceil(scale_factor * w)
  96. colors = color_list() # list of colors
  97. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  98. for i, img in enumerate(images):
  99. if i == max_subplots: # if last batch has fewer images than we expect
  100. break
  101. block_x = int(w * (i // ns))
  102. block_y = int(h * (i % ns))
  103. img = img.transpose(1, 2, 0)
  104. if scale_factor < 1:
  105. img = cv2.resize(img, (w, h))
  106. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  107. if len(targets) > 0:
  108. image_targets = targets[targets[:, 0] == i]
  109. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  110. classes = image_targets[:, 1].astype('int')
  111. labels = image_targets.shape[1] == 6 # labels if no conf column
  112. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  113. if boxes.shape[1]:
  114. if boxes.max() <= 1: # if normalized
  115. boxes[[0, 2]] *= w # scale to pixels
  116. boxes[[1, 3]] *= h
  117. elif scale_factor < 1: # absolute coords need scale if image scales
  118. boxes *= scale_factor
  119. boxes[[0, 2]] += block_x
  120. boxes[[1, 3]] += block_y
  121. for j, box in enumerate(boxes.T):
  122. cls = int(classes[j])
  123. color = colors[cls % len(colors)]
  124. cls = names[cls] if names else cls
  125. if labels or conf[j] > 0.25: # 0.25 conf thresh
  126. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  127. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  128. # Draw image filename labels
  129. if paths:
  130. label = Path(paths[i]).name[:40] # trim to 40 char
  131. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  132. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  133. lineType=cv2.LINE_AA)
  134. # Image border
  135. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  136. if fname:
  137. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  138. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  139. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  140. Image.fromarray(mosaic).save(fname) # PIL save
  141. return mosaic
  142. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  143. # Plot LR simulating training for full epochs
  144. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  145. y = []
  146. for _ in range(epochs):
  147. scheduler.step()
  148. y.append(optimizer.param_groups[0]['lr'])
  149. plt.plot(y, '.-', label='LR')
  150. plt.xlabel('epoch')
  151. plt.ylabel('LR')
  152. plt.grid()
  153. plt.xlim(0, epochs)
  154. plt.ylim(0)
  155. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  156. def plot_test_txt(): # from utils.plots import *; plot_test()
  157. # Plot test.txt histograms
  158. x = np.loadtxt('test.txt', dtype=np.float32)
  159. box = xyxy2xywh(x[:, :4])
  160. cx, cy = box[:, 0], box[:, 1]
  161. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  162. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  163. ax.set_aspect('equal')
  164. plt.savefig('hist2d.png', dpi=300)
  165. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  166. ax[0].hist(cx, bins=600)
  167. ax[1].hist(cy, bins=600)
  168. plt.savefig('hist1d.png', dpi=200)
  169. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  170. # Plot targets.txt histograms
  171. x = np.loadtxt('targets.txt', dtype=np.float32).T
  172. s = ['x targets', 'y targets', 'width targets', 'height targets']
  173. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  174. ax = ax.ravel()
  175. for i in range(4):
  176. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  177. ax[i].legend()
  178. ax[i].set_title(s[i])
  179. plt.savefig('targets.jpg', dpi=200)
  180. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  181. # Plot study.txt generated by test.py
  182. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  183. ax = ax.ravel()
  184. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  185. for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s', 'yolov5m', 'yolov5l', 'yolov5x']]:
  186. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  187. x = np.arange(y.shape[1]) if x is None else np.array(x)
  188. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  189. for i in range(7):
  190. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  191. ax[i].set_title(s[i])
  192. j = y[3].argmax() + 1
  193. ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
  194. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  195. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  196. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  197. ax2.grid()
  198. ax2.set_xlim(0, 30)
  199. ax2.set_ylim(28, 50)
  200. ax2.set_yticks(np.arange(30, 55, 5))
  201. ax2.set_xlabel('GPU Speed (ms/img)')
  202. ax2.set_ylabel('COCO AP val')
  203. ax2.legend(loc='lower right')
  204. plt.savefig('test_study.png', dpi=300)
  205. def plot_labels(labels, save_dir=Path(''), loggers=None):
  206. # plot dataset labels
  207. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  208. nc = int(c.max() + 1) # number of classes
  209. colors = color_list()
  210. # seaborn correlogram
  211. try:
  212. import seaborn as sns
  213. import pandas as pd
  214. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  215. sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
  216. plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
  217. diag_kws=dict(bins=50))
  218. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  219. plt.close()
  220. except Exception as e:
  221. pass
  222. # matplotlib labels
  223. matplotlib.use('svg') # faster
  224. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  225. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  226. ax[0].set_xlabel('classes')
  227. ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
  228. ax[2].set_xlabel('x')
  229. ax[2].set_ylabel('y')
  230. ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
  231. ax[3].set_xlabel('width')
  232. ax[3].set_ylabel('height')
  233. # rectangles
  234. labels[:, 1:3] = 0.5 # center
  235. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  236. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  237. for cls, *box in labels[:1000]:
  238. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
  239. ax[1].imshow(img)
  240. ax[1].axis('off')
  241. for a in [0, 1, 2, 3]:
  242. for s in ['top', 'right', 'left', 'bottom']:
  243. ax[a].spines[s].set_visible(False)
  244. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  245. matplotlib.use('Agg')
  246. plt.close()
  247. # loggers
  248. for k, v in loggers.items() or {}:
  249. if k == 'wandb' and v:
  250. v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]})
  251. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  252. # Plot hyperparameter evolution results in evolve.txt
  253. with open(yaml_file) as f:
  254. hyp = yaml.load(f, Loader=yaml.FullLoader)
  255. x = np.loadtxt('evolve.txt', ndmin=2)
  256. f = fitness(x)
  257. # weights = (f - f.min()) ** 2 # for weighted results
  258. plt.figure(figsize=(10, 12), tight_layout=True)
  259. matplotlib.rc('font', **{'size': 8})
  260. for i, (k, v) in enumerate(hyp.items()):
  261. y = x[:, i + 7]
  262. # mu = (y * weights).sum() / weights.sum() # best weighted result
  263. mu = y[f.argmax()] # best single result
  264. plt.subplot(6, 5, i + 1)
  265. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  266. plt.plot(mu, f.max(), 'k+', markersize=15)
  267. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  268. if i % 5 != 0:
  269. plt.yticks([])
  270. print('%15s: %.3g' % (k, mu))
  271. plt.savefig('evolve.png', dpi=200)
  272. print('\nPlot saved as evolve.png')
  273. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  274. # Plot training 'results*.txt', overlaying train and val losses
  275. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  276. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  277. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  278. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  279. n = results.shape[1] # number of rows
  280. x = range(start, min(stop, n) if stop else n)
  281. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  282. ax = ax.ravel()
  283. for i in range(5):
  284. for j in [i, i + 5]:
  285. y = results[j, x]
  286. ax[i].plot(x, y, marker='.', label=s[j])
  287. # y_smooth = butter_lowpass_filtfilt(y)
  288. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  289. ax[i].set_title(t[i])
  290. ax[i].legend()
  291. ax[i].set_ylabel(f) if i == 0 else None # add filename
  292. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  293. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  294. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  295. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  296. ax = ax.ravel()
  297. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  298. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  299. if bucket:
  300. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  301. files = ['results%g.txt' % x for x in id]
  302. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  303. os.system(c)
  304. else:
  305. files = list(Path(save_dir).glob('results*.txt'))
  306. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  307. for fi, f in enumerate(files):
  308. try:
  309. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  310. n = results.shape[1] # number of rows
  311. x = range(start, min(stop, n) if stop else n)
  312. for i in range(10):
  313. y = results[i, x]
  314. if i in [0, 1, 2, 5, 6, 7]:
  315. y[y == 0] = np.nan # don't show zero loss values
  316. # y /= y[0] # normalize
  317. label = labels[fi] if len(labels) else f.stem
  318. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  319. ax[i].set_title(s[i])
  320. # if i in [5, 6, 7]: # share train and val loss y axes
  321. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  322. except Exception as e:
  323. print('Warning: Plotting error for %s; %s' % (f, e))
  324. ax[1].legend()
  325. fig.savefig(Path(save_dir) / 'results.png', dpi=200)