You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

379 lines
15KB

  1. # Plotting utils
  2. import glob
  3. import math
  4. import os
  5. import random
  6. from copy import copy
  7. from pathlib import Path
  8. import cv2
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import torch
  13. import yaml
  14. from PIL import Image
  15. from scipy.signal import butter, filtfilt
  16. from utils.general import xywh2xyxy, xyxy2xywh
  17. from utils.metrics import fitness
# Settings
# Select the Agg backend at import time: the figures in this module are only written to files.
matplotlib.use('Agg')  # for writing to files only
  20. def color_list():
  21. # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  22. def hex2rgb(h):
  23. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  24. return [hex2rgb(h) for h in plt.rcParams['axes.prop_cycle'].by_key()['color']]
  25. def hist2d(x, y, n=100):
  26. # 2d histogram used in labels.png and evolve.png
  27. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  28. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  29. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  30. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  31. return np.log(hist[xidx, yidx])
  32. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  33. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  34. def butter_lowpass(cutoff, fs, order):
  35. nyq = 0.5 * fs
  36. normal_cutoff = cutoff / nyq
  37. return butter(order, normal_cutoff, btype='low', analog=False)
  38. b, a = butter_lowpass(cutoff, fs, order=order)
  39. return filtfilt(b, a, data) # forward-backward filter
  40. def plot_one_box(x, img, color=None, label=None, line_thickness=None):
  41. # Plots one bounding box on image img
  42. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  43. color = color or [random.randint(0, 255) for _ in range(3)]
  44. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  45. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  46. if label:
  47. tf = max(tl - 1, 1) # font thickness
  48. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  49. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  50. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  51. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  52. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  53. # Compares the two methods for width-height anchor multiplication
  54. # https://github.com/ultralytics/yolov3/issues/168
  55. x = np.arange(-4.0, 4.0, .1)
  56. ya = np.exp(x)
  57. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  58. fig = plt.figure(figsize=(6, 3), dpi=150)
  59. plt.plot(x, ya, '.-', label='YOLOv3')
  60. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  61. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  62. plt.xlim(left=-4, right=4)
  63. plt.ylim(bottom=0, top=6)
  64. plt.xlabel('input')
  65. plt.ylabel('output')
  66. plt.grid()
  67. plt.legend()
  68. fig.tight_layout()
  69. fig.savefig('comparison.png', dpi=200)
  70. def output_to_target(output, width, height):
  71. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  72. if isinstance(output, torch.Tensor):
  73. output = output.cpu().numpy()
  74. targets = []
  75. for i, o in enumerate(output):
  76. if o is not None:
  77. for pred in o:
  78. box = pred[:4]
  79. w = (box[2] - box[0]) / width
  80. h = (box[3] - box[1]) / height
  81. x = box[0] / width + w / 2
  82. y = box[1] / height + h / 2
  83. conf = pred[4]
  84. cls = int(pred[5])
  85. targets.append([i, cls, x, y, w, h, conf])
  86. return np.array(targets)
  87. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  88. # Plot image grid with labels
  89. if isinstance(images, torch.Tensor):
  90. images = images.cpu().float().numpy()
  91. if isinstance(targets, torch.Tensor):
  92. targets = targets.cpu().numpy()
  93. # un-normalise
  94. if np.max(images[0]) <= 1:
  95. images *= 255
  96. tl = 3 # line thickness
  97. tf = max(tl - 1, 1) # font thickness
  98. bs, _, h, w = images.shape # batch size, _, height, width
  99. bs = min(bs, max_subplots) # limit plot images
  100. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  101. # Check if we should resize
  102. scale_factor = max_size / max(h, w)
  103. if scale_factor < 1:
  104. h = math.ceil(scale_factor * h)
  105. w = math.ceil(scale_factor * w)
  106. colors = color_list() # list of colors
  107. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  108. for i, img in enumerate(images):
  109. if i == max_subplots: # if last batch has fewer images than we expect
  110. break
  111. block_x = int(w * (i // ns))
  112. block_y = int(h * (i % ns))
  113. img = img.transpose(1, 2, 0)
  114. if scale_factor < 1:
  115. img = cv2.resize(img, (w, h))
  116. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  117. if len(targets) > 0:
  118. image_targets = targets[targets[:, 0] == i]
  119. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  120. classes = image_targets[:, 1].astype('int')
  121. labels = image_targets.shape[1] == 6 # labels if no conf column
  122. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  123. boxes[[0, 2]] *= w
  124. boxes[[0, 2]] += block_x
  125. boxes[[1, 3]] *= h
  126. boxes[[1, 3]] += block_y
  127. for j, box in enumerate(boxes.T):
  128. cls = int(classes[j])
  129. color = colors[cls % len(colors)]
  130. cls = names[cls] if names else cls
  131. if labels or conf[j] > 0.25: # 0.25 conf thresh
  132. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  133. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  134. # Draw image filename labels
  135. if paths:
  136. label = Path(paths[i]).name[:40] # trim to 40 char
  137. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  138. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  139. lineType=cv2.LINE_AA)
  140. # Image border
  141. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  142. if fname:
  143. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  144. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  145. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  146. Image.fromarray(mosaic).save(fname) # PIL save
  147. return mosaic
  148. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  149. # Plot LR simulating training for full epochs
  150. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  151. y = []
  152. for _ in range(epochs):
  153. scheduler.step()
  154. y.append(optimizer.param_groups[0]['lr'])
  155. plt.plot(y, '.-', label='LR')
  156. plt.xlabel('epoch')
  157. plt.ylabel('LR')
  158. plt.grid()
  159. plt.xlim(0, epochs)
  160. plt.ylim(0)
  161. plt.tight_layout()
  162. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  163. def plot_test_txt(): # from utils.plots import *; plot_test()
  164. # Plot test.txt histograms
  165. x = np.loadtxt('test.txt', dtype=np.float32)
  166. box = xyxy2xywh(x[:, :4])
  167. cx, cy = box[:, 0], box[:, 1]
  168. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  169. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  170. ax.set_aspect('equal')
  171. plt.savefig('hist2d.png', dpi=300)
  172. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  173. ax[0].hist(cx, bins=600)
  174. ax[1].hist(cy, bins=600)
  175. plt.savefig('hist1d.png', dpi=200)
  176. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  177. # Plot targets.txt histograms
  178. x = np.loadtxt('targets.txt', dtype=np.float32).T
  179. s = ['x targets', 'y targets', 'width targets', 'height targets']
  180. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  181. ax = ax.ravel()
  182. for i in range(4):
  183. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  184. ax[i].legend()
  185. ax[i].set_title(s[i])
  186. plt.savefig('targets.jpg', dpi=200)
  187. def plot_study_txt(f='study.txt', x=None): # from utils.plots import *; plot_study_txt()
  188. # Plot study.txt generated by test.py
  189. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  190. ax = ax.ravel()
  191. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  192. for f in ['study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
  193. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  194. x = np.arange(y.shape[1]) if x is None else np.array(x)
  195. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  196. for i in range(7):
  197. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  198. ax[i].set_title(s[i])
  199. j = y[3].argmax() + 1
  200. ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
  201. label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  202. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  203. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  204. ax2.grid()
  205. ax2.set_xlim(0, 30)
  206. ax2.set_ylim(28, 50)
  207. ax2.set_yticks(np.arange(30, 55, 5))
  208. ax2.set_xlabel('GPU Speed (ms/img)')
  209. ax2.set_ylabel('COCO AP val')
  210. ax2.legend(loc='lower right')
  211. plt.savefig('study_mAP_latency.png', dpi=300)
  212. plt.savefig(f.replace('.txt', '.png'), dpi=300)
  213. def plot_labels(labels, save_dir=''):
  214. # plot dataset labels
  215. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  216. nc = int(c.max() + 1) # number of classes
  217. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  218. ax = ax.ravel()
  219. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  220. ax[0].set_xlabel('classes')
  221. ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
  222. ax[1].set_xlabel('x')
  223. ax[1].set_ylabel('y')
  224. ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
  225. ax[2].set_xlabel('width')
  226. ax[2].set_ylabel('height')
  227. plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
  228. plt.close()
  229. # seaborn correlogram
  230. try:
  231. import seaborn as sns
  232. import pandas as pd
  233. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  234. sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
  235. plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
  236. diag_kws=dict(bins=50))
  237. plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
  238. plt.close()
  239. except Exception as e:
  240. pass
  241. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  242. # Plot hyperparameter evolution results in evolve.txt
  243. with open(yaml_file) as f:
  244. hyp = yaml.load(f, Loader=yaml.FullLoader)
  245. x = np.loadtxt('evolve.txt', ndmin=2)
  246. f = fitness(x)
  247. # weights = (f - f.min()) ** 2 # for weighted results
  248. plt.figure(figsize=(10, 12), tight_layout=True)
  249. matplotlib.rc('font', **{'size': 8})
  250. for i, (k, v) in enumerate(hyp.items()):
  251. y = x[:, i + 7]
  252. # mu = (y * weights).sum() / weights.sum() # best weighted result
  253. mu = y[f.argmax()] # best single result
  254. plt.subplot(6, 5, i + 1)
  255. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  256. plt.plot(mu, f.max(), 'k+', markersize=15)
  257. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  258. if i % 5 != 0:
  259. plt.yticks([])
  260. print('%15s: %.3g' % (k, mu))
  261. plt.savefig('evolve.png', dpi=200)
  262. print('\nPlot saved as evolve.png')
  263. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  264. # Plot training 'results*.txt', overlaying train and val losses
  265. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  266. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  267. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  268. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  269. n = results.shape[1] # number of rows
  270. x = range(start, min(stop, n) if stop else n)
  271. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  272. ax = ax.ravel()
  273. for i in range(5):
  274. for j in [i, i + 5]:
  275. y = results[j, x]
  276. ax[i].plot(x, y, marker='.', label=s[j])
  277. # y_smooth = butter_lowpass_filtfilt(y)
  278. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  279. ax[i].set_title(t[i])
  280. ax[i].legend()
  281. ax[i].set_ylabel(f) if i == 0 else None # add filename
  282. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  283. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  284. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  285. fig, ax = plt.subplots(2, 5, figsize=(12, 6))
  286. ax = ax.ravel()
  287. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  288. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  289. if bucket:
  290. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  291. files = ['results%g.txt' % x for x in id]
  292. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  293. os.system(c)
  294. else:
  295. files = list(Path(save_dir).glob('results*.txt'))
  296. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  297. for fi, f in enumerate(files):
  298. try:
  299. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  300. n = results.shape[1] # number of rows
  301. x = range(start, min(stop, n) if stop else n)
  302. for i in range(10):
  303. y = results[i, x]
  304. if i in [0, 1, 2, 5, 6, 7]:
  305. y[y == 0] = np.nan # don't show zero loss values
  306. # y /= y[0] # normalize
  307. label = labels[fi] if len(labels) else f.stem
  308. ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6)
  309. ax[i].set_title(s[i])
  310. # if i in [5, 6, 7]: # share train and val loss y axes
  311. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  312. except Exception as e:
  313. print('Warning: Plotting error for %s; %s' % (f, e))
  314. fig.tight_layout()
  315. ax[1].legend()
  316. fig.savefig(Path(save_dir) / 'results.png', dpi=200)