You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

436 lines
18KB

  1. # Plotting utils
  2. import glob
  3. import math
  4. import os
  5. import random
  6. from copy import copy
  7. from pathlib import Path
  8. import cv2
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import pandas as pd
  13. import seaborn as sns
  14. import torch
  15. import yaml
  16. from PIL import Image, ImageDraw, ImageFont
  17. from scipy.signal import butter, filtfilt
  18. from utils.general import xywh2xyxy, xyxy2xywh
  19. from utils.metrics import fitness
  20. # Settings
  21. matplotlib.rc('font', **{'size': 11})
  22. matplotlib.use('Agg') # for writing to files only
  23. def color_list():
  24. # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  25. def hex2rgb(h):
  26. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  27. return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949)
  28. def hist2d(x, y, n=100):
  29. # 2d histogram used in labels.png and evolve.png
  30. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  31. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  32. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  33. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  34. return np.log(hist[xidx, yidx])
  35. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  36. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  37. def butter_lowpass(cutoff, fs, order):
  38. nyq = 0.5 * fs
  39. normal_cutoff = cutoff / nyq
  40. return butter(order, normal_cutoff, btype='low', analog=False)
  41. b, a = butter_lowpass(cutoff, fs, order=order)
  42. return filtfilt(b, a, data) # forward-backward filter
  43. def plot_one_box(x, im, color=None, label=None, line_thickness=3):
  44. # Plots one bounding box on image 'im' using OpenCV
  45. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
  46. tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
  47. color = color or [random.randint(0, 255) for _ in range(3)]
  48. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  49. cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  50. if label:
  51. tf = max(tl - 1, 1) # font thickness
  52. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  53. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  54. cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
  55. cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  56. def plot_one_box_PIL(box, im, color=None, label=None, line_thickness=None):
  57. # Plots one bounding box on image 'im' using PIL
  58. im = Image.fromarray(im)
  59. draw = ImageDraw.Draw(im)
  60. line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
  61. draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot
  62. if label:
  63. fontsize = max(round(max(im.size) / 40), 12)
  64. font = ImageFont.truetype("Arial.ttf", fontsize)
  65. txt_width, txt_height = font.getsize(label)
  66. draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
  67. draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
  68. return np.asarray(im)
  69. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  70. # Compares the two methods for width-height anchor multiplication
  71. # https://github.com/ultralytics/yolov3/issues/168
  72. x = np.arange(-4.0, 4.0, .1)
  73. ya = np.exp(x)
  74. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  75. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  76. plt.plot(x, ya, '.-', label='YOLOv3')
  77. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  78. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  79. plt.xlim(left=-4, right=4)
  80. plt.ylim(bottom=0, top=6)
  81. plt.xlabel('input')
  82. plt.ylabel('output')
  83. plt.grid()
  84. plt.legend()
  85. fig.savefig('comparison.png', dpi=200)
  86. def output_to_target(output):
  87. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  88. targets = []
  89. for i, o in enumerate(output):
  90. for *box, conf, cls in o.cpu().numpy():
  91. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  92. return np.array(targets)
  93. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  94. # Plot image grid with labels
  95. if isinstance(images, torch.Tensor):
  96. images = images.cpu().float().numpy()
  97. if isinstance(targets, torch.Tensor):
  98. targets = targets.cpu().numpy()
  99. # un-normalise
  100. if np.max(images[0]) <= 1:
  101. images *= 255
  102. tl = 3 # line thickness
  103. tf = max(tl - 1, 1) # font thickness
  104. bs, _, h, w = images.shape # batch size, _, height, width
  105. bs = min(bs, max_subplots) # limit plot images
  106. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  107. # Check if we should resize
  108. scale_factor = max_size / max(h, w)
  109. if scale_factor < 1:
  110. h = math.ceil(scale_factor * h)
  111. w = math.ceil(scale_factor * w)
  112. colors = color_list() # list of colors
  113. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  114. for i, img in enumerate(images):
  115. if i == max_subplots: # if last batch has fewer images than we expect
  116. break
  117. block_x = int(w * (i // ns))
  118. block_y = int(h * (i % ns))
  119. img = img.transpose(1, 2, 0)
  120. if scale_factor < 1:
  121. img = cv2.resize(img, (w, h))
  122. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  123. if len(targets) > 0:
  124. image_targets = targets[targets[:, 0] == i]
  125. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  126. classes = image_targets[:, 1].astype('int')
  127. labels = image_targets.shape[1] == 6 # labels if no conf column
  128. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  129. if boxes.shape[1]:
  130. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  131. boxes[[0, 2]] *= w # scale to pixels
  132. boxes[[1, 3]] *= h
  133. elif scale_factor < 1: # absolute coords need scale if image scales
  134. boxes *= scale_factor
  135. boxes[[0, 2]] += block_x
  136. boxes[[1, 3]] += block_y
  137. for j, box in enumerate(boxes.T):
  138. cls = int(classes[j])
  139. color = colors[cls % len(colors)]
  140. cls = names[cls] if names else cls
  141. if labels or conf[j] > 0.25: # 0.25 conf thresh
  142. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  143. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  144. # Draw image filename labels
  145. if paths:
  146. label = Path(paths[i]).name[:40] # trim to 40 char
  147. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  148. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  149. lineType=cv2.LINE_AA)
  150. # Image border
  151. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  152. if fname:
  153. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  154. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  155. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  156. Image.fromarray(mosaic).save(fname) # PIL save
  157. return mosaic
  158. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  159. # Plot LR simulating training for full epochs
  160. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  161. y = []
  162. for _ in range(epochs):
  163. scheduler.step()
  164. y.append(optimizer.param_groups[0]['lr'])
  165. plt.plot(y, '.-', label='LR')
  166. plt.xlabel('epoch')
  167. plt.ylabel('LR')
  168. plt.grid()
  169. plt.xlim(0, epochs)
  170. plt.ylim(0)
  171. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  172. plt.close()
  173. def plot_test_txt(): # from utils.plots import *; plot_test()
  174. # Plot test.txt histograms
  175. x = np.loadtxt('test.txt', dtype=np.float32)
  176. box = xyxy2xywh(x[:, :4])
  177. cx, cy = box[:, 0], box[:, 1]
  178. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  179. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  180. ax.set_aspect('equal')
  181. plt.savefig('hist2d.png', dpi=300)
  182. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  183. ax[0].hist(cx, bins=600)
  184. ax[1].hist(cy, bins=600)
  185. plt.savefig('hist1d.png', dpi=200)
  186. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  187. # Plot targets.txt histograms
  188. x = np.loadtxt('targets.txt', dtype=np.float32).T
  189. s = ['x targets', 'y targets', 'width targets', 'height targets']
  190. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  191. ax = ax.ravel()
  192. for i in range(4):
  193. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  194. ax[i].legend()
  195. ax[i].set_title(s[i])
  196. plt.savefig('targets.jpg', dpi=200)
  197. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  198. # Plot study.txt generated by test.py
  199. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  200. # ax = ax.ravel()
  201. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  202. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  203. for f in sorted(Path(path).glob('study*.txt')):
  204. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  205. x = np.arange(y.shape[1]) if x is None else np.array(x)
  206. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  207. # for i in range(7):
  208. # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  209. # ax[i].set_title(s[i])
  210. j = y[3].argmax() + 1
  211. ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  212. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  213. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  214. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  215. ax2.grid(alpha=0.2)
  216. ax2.set_yticks(np.arange(20, 60, 5))
  217. ax2.set_xlim(0, 57)
  218. ax2.set_ylim(30, 55)
  219. ax2.set_xlabel('GPU Speed (ms/img)')
  220. ax2.set_ylabel('COCO AP val')
  221. ax2.legend(loc='lower right')
  222. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  223. def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
  224. # plot dataset labels
  225. print('Plotting labels... ')
  226. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  227. nc = int(c.max() + 1) # number of classes
  228. colors = color_list()
  229. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  230. # seaborn correlogram
  231. sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  232. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  233. plt.close()
  234. # matplotlib labels
  235. matplotlib.use('svg') # faster
  236. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  237. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  238. ax[0].set_ylabel('instances')
  239. if 0 < len(names) < 30:
  240. ax[0].set_xticks(range(len(names)))
  241. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  242. else:
  243. ax[0].set_xlabel('classes')
  244. sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  245. sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  246. # rectangles
  247. labels[:, 1:3] = 0.5 # center
  248. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  249. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  250. for cls, *box in labels[:1000]:
  251. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
  252. ax[1].imshow(img)
  253. ax[1].axis('off')
  254. for a in [0, 1, 2, 3]:
  255. for s in ['top', 'right', 'left', 'bottom']:
  256. ax[a].spines[s].set_visible(False)
  257. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  258. matplotlib.use('Agg')
  259. plt.close()
  260. # loggers
  261. for k, v in loggers.items() or {}:
  262. if k == 'wandb' and v:
  263. v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
  264. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  265. # Plot hyperparameter evolution results in evolve.txt
  266. with open(yaml_file) as f:
  267. hyp = yaml.safe_load(f)
  268. x = np.loadtxt('evolve.txt', ndmin=2)
  269. f = fitness(x)
  270. # weights = (f - f.min()) ** 2 # for weighted results
  271. plt.figure(figsize=(10, 12), tight_layout=True)
  272. matplotlib.rc('font', **{'size': 8})
  273. for i, (k, v) in enumerate(hyp.items()):
  274. y = x[:, i + 7]
  275. # mu = (y * weights).sum() / weights.sum() # best weighted result
  276. mu = y[f.argmax()] # best single result
  277. plt.subplot(6, 5, i + 1)
  278. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  279. plt.plot(mu, f.max(), 'k+', markersize=15)
  280. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  281. if i % 5 != 0:
  282. plt.yticks([])
  283. print('%15s: %.3g' % (k, mu))
  284. plt.savefig('evolve.png', dpi=200)
  285. print('\nPlot saved as evolve.png')
  286. def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
  287. # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
  288. ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
  289. s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
  290. files = list(Path(save_dir).glob('frames*.txt'))
  291. for fi, f in enumerate(files):
  292. try:
  293. results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
  294. n = results.shape[1] # number of rows
  295. x = np.arange(start, min(stop, n) if stop else n)
  296. results = results[:, x]
  297. t = (results[0] - results[0].min()) # set t0=0s
  298. results[0] = x
  299. for i, a in enumerate(ax):
  300. if i < len(results):
  301. label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
  302. a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
  303. a.set_title(s[i])
  304. a.set_xlabel('time (s)')
  305. # if fi == len(files) - 1:
  306. # a.set_ylim(bottom=0)
  307. for side in ['top', 'right']:
  308. a.spines[side].set_visible(False)
  309. else:
  310. a.remove()
  311. except Exception as e:
  312. print('Warning: Plotting error for %s; %s' % (f, e))
  313. ax[1].legend()
  314. plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  315. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  316. # Plot training 'results*.txt', overlaying train and val losses
  317. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  318. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  319. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  320. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  321. n = results.shape[1] # number of rows
  322. x = range(start, min(stop, n) if stop else n)
  323. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  324. ax = ax.ravel()
  325. for i in range(5):
  326. for j in [i, i + 5]:
  327. y = results[j, x]
  328. ax[i].plot(x, y, marker='.', label=s[j])
  329. # y_smooth = butter_lowpass_filtfilt(y)
  330. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  331. ax[i].set_title(t[i])
  332. ax[i].legend()
  333. ax[i].set_ylabel(f) if i == 0 else None # add filename
  334. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  335. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  336. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  337. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  338. ax = ax.ravel()
  339. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  340. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  341. if bucket:
  342. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  343. files = ['results%g.txt' % x for x in id]
  344. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  345. os.system(c)
  346. else:
  347. files = list(Path(save_dir).glob('results*.txt'))
  348. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  349. for fi, f in enumerate(files):
  350. try:
  351. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  352. n = results.shape[1] # number of rows
  353. x = range(start, min(stop, n) if stop else n)
  354. for i in range(10):
  355. y = results[i, x]
  356. if i in [0, 1, 2, 5, 6, 7]:
  357. y[y == 0] = np.nan # don't show zero loss values
  358. # y /= y[0] # normalize
  359. label = labels[fi] if len(labels) else f.stem
  360. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  361. ax[i].set_title(s[i])
  362. # if i in [5, 6, 7]: # share train and val loss y axes
  363. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  364. except Exception as e:
  365. print('Warning: Plotting error for %s; %s' % (f, e))
  366. ax[1].legend()
  367. fig.savefig(Path(save_dir) / 'results.png', dpi=200)