Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

448 lines
19KB

  1. # Plotting utils
  2. import glob
  3. import math
  4. import os
  5. from copy import copy
  6. from pathlib import Path
  7. import cv2
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. import pandas as pd
  12. import seaborn as sn
  13. import torch
  14. import yaml
  15. from PIL import Image, ImageDraw, ImageFont
  16. from utils.general import xywh2xyxy, xyxy2xywh
  17. from utils.metrics import fitness
  18. # Settings
  19. matplotlib.rc('font', **{'size': 11})
  20. matplotlib.use('Agg') # for writing to files only
  21. class Colors:
  22. # Ultralytics color palette https://ultralytics.com/
  23. def __init__(self):
  24. # hex = matplotlib.colors.TABLEAU_COLORS.values()
  25. hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
  26. '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
  27. self.palette = [self.hex2rgb('#' + c) for c in hex]
  28. self.n = len(self.palette)
  29. def __call__(self, i, bgr=False):
  30. c = self.palette[int(i) % self.n]
  31. return (c[2], c[1], c[0]) if bgr else c
  32. @staticmethod
  33. def hex2rgb(h): # rgb order (PIL)
  34. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  35. colors = Colors() # create instance for 'from utils.plots import colors'
  36. def hist2d(x, y, n=100):
  37. # 2d histogram used in labels.png and evolve.png
  38. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  39. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  40. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  41. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  42. return np.log(hist[xidx, yidx])
  43. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  44. from scipy.signal import butter, filtfilt
  45. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  46. def butter_lowpass(cutoff, fs, order):
  47. nyq = 0.5 * fs
  48. normal_cutoff = cutoff / nyq
  49. return butter(order, normal_cutoff, btype='low', analog=False)
  50. b, a = butter_lowpass(cutoff, fs, order=order)
  51. return filtfilt(b, a, data) # forward-backward filter
  52. def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
  53. # Plots one bounding box on image 'im' using OpenCV
  54. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
  55. tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
  56. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  57. cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  58. if label:
  59. tf = max(tl - 1, 1) # font thickness
  60. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  61. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  62. cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
  63. cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  64. def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
  65. # Plots one bounding box on image 'im' using PIL
  66. im = Image.fromarray(im)
  67. draw = ImageDraw.Draw(im)
  68. line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
  69. draw.rectangle(box, width=line_thickness, outline=color) # plot
  70. if label:
  71. font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
  72. txt_width, txt_height = font.getsize(label)
  73. draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
  74. draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
  75. return np.asarray(im)
  76. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  77. # Compares the two methods for width-height anchor multiplication
  78. # https://github.com/ultralytics/yolov3/issues/168
  79. x = np.arange(-4.0, 4.0, .1)
  80. ya = np.exp(x)
  81. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  82. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  83. plt.plot(x, ya, '.-', label='YOLOv3')
  84. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  85. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  86. plt.xlim(left=-4, right=4)
  87. plt.ylim(bottom=0, top=6)
  88. plt.xlabel('input')
  89. plt.ylabel('output')
  90. plt.grid()
  91. plt.legend()
  92. fig.savefig('comparison.png', dpi=200)
  93. def output_to_target(output):
  94. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  95. targets = []
  96. for i, o in enumerate(output):
  97. for *box, conf, cls in o.cpu().numpy():
  98. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  99. return np.array(targets)
  100. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  101. # Plot image grid with labels
  102. if isinstance(images, torch.Tensor):
  103. images = images.cpu().float().numpy()
  104. if isinstance(targets, torch.Tensor):
  105. targets = targets.cpu().numpy()
  106. # un-normalise
  107. if np.max(images[0]) <= 1:
  108. images *= 255
  109. tl = 3 # line thickness
  110. tf = max(tl - 1, 1) # font thickness
  111. bs, _, h, w = images.shape # batch size, _, height, width
  112. bs = min(bs, max_subplots) # limit plot images
  113. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  114. # Check if we should resize
  115. scale_factor = max_size / max(h, w)
  116. if scale_factor < 1:
  117. h = math.ceil(scale_factor * h)
  118. w = math.ceil(scale_factor * w)
  119. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  120. for i, img in enumerate(images):
  121. if i == max_subplots: # if last batch has fewer images than we expect
  122. break
  123. block_x = int(w * (i // ns))
  124. block_y = int(h * (i % ns))
  125. img = img.transpose(1, 2, 0)
  126. if scale_factor < 1:
  127. img = cv2.resize(img, (w, h))
  128. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  129. if len(targets) > 0:
  130. image_targets = targets[targets[:, 0] == i]
  131. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  132. classes = image_targets[:, 1].astype('int')
  133. labels = image_targets.shape[1] == 6 # labels if no conf column
  134. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  135. if boxes.shape[1]:
  136. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  137. boxes[[0, 2]] *= w # scale to pixels
  138. boxes[[1, 3]] *= h
  139. elif scale_factor < 1: # absolute coords need scale if image scales
  140. boxes *= scale_factor
  141. boxes[[0, 2]] += block_x
  142. boxes[[1, 3]] += block_y
  143. for j, box in enumerate(boxes.T):
  144. cls = int(classes[j])
  145. color = colors(cls)
  146. cls = names[cls] if names else cls
  147. if labels or conf[j] > 0.25: # 0.25 conf thresh
  148. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  149. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  150. # Draw image filename labels
  151. if paths:
  152. label = Path(paths[i]).name[:40] # trim to 40 char
  153. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  154. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  155. lineType=cv2.LINE_AA)
  156. # Image border
  157. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  158. if fname:
  159. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  160. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  161. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  162. Image.fromarray(mosaic).save(fname) # PIL save
  163. return mosaic
  164. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  165. # Plot LR simulating training for full epochs
  166. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  167. y = []
  168. for _ in range(epochs):
  169. scheduler.step()
  170. y.append(optimizer.param_groups[0]['lr'])
  171. plt.plot(y, '.-', label='LR')
  172. plt.xlabel('epoch')
  173. plt.ylabel('LR')
  174. plt.grid()
  175. plt.xlim(0, epochs)
  176. plt.ylim(0)
  177. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  178. plt.close()
  179. def plot_test_txt(): # from utils.plots import *; plot_test()
  180. # Plot test.txt histograms
  181. x = np.loadtxt('test.txt', dtype=np.float32)
  182. box = xyxy2xywh(x[:, :4])
  183. cx, cy = box[:, 0], box[:, 1]
  184. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  185. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  186. ax.set_aspect('equal')
  187. plt.savefig('hist2d.png', dpi=300)
  188. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  189. ax[0].hist(cx, bins=600)
  190. ax[1].hist(cy, bins=600)
  191. plt.savefig('hist1d.png', dpi=200)
  192. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  193. # Plot targets.txt histograms
  194. x = np.loadtxt('targets.txt', dtype=np.float32).T
  195. s = ['x targets', 'y targets', 'width targets', 'height targets']
  196. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  197. ax = ax.ravel()
  198. for i in range(4):
  199. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  200. ax[i].legend()
  201. ax[i].set_title(s[i])
  202. plt.savefig('targets.jpg', dpi=200)
  203. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  204. # Plot study.txt generated by test.py
  205. plot2 = False # plot additional results
  206. if plot2:
  207. ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
  208. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  209. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  210. for f in sorted(Path(path).glob('study*.txt')):
  211. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  212. x = np.arange(y.shape[1]) if x is None else np.array(x)
  213. if plot2:
  214. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
  215. for i in range(7):
  216. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  217. ax[i].set_title(s[i])
  218. j = y[3].argmax() + 1
  219. ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  220. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  221. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  222. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  223. ax2.grid(alpha=0.2)
  224. ax2.set_yticks(np.arange(20, 60, 5))
  225. ax2.set_xlim(0, 57)
  226. ax2.set_ylim(30, 55)
  227. ax2.set_xlabel('GPU Speed (ms/img)')
  228. ax2.set_ylabel('COCO AP val')
  229. ax2.legend(loc='lower right')
  230. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  231. def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
  232. # plot dataset labels
  233. print('Plotting labels... ')
  234. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  235. nc = int(c.max() + 1) # number of classes
  236. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  237. # seaborn correlogram
  238. sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  239. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  240. plt.close()
  241. # matplotlib labels
  242. matplotlib.use('svg') # faster
  243. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  244. y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  245. # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
  246. ax[0].set_ylabel('instances')
  247. if 0 < len(names) < 30:
  248. ax[0].set_xticks(range(len(names)))
  249. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  250. else:
  251. ax[0].set_xlabel('classes')
  252. sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  253. sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  254. # rectangles
  255. labels[:, 1:3] = 0.5 # center
  256. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  257. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  258. for cls, *box in labels[:1000]:
  259. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  260. ax[1].imshow(img)
  261. ax[1].axis('off')
  262. for a in [0, 1, 2, 3]:
  263. for s in ['top', 'right', 'left', 'bottom']:
  264. ax[a].spines[s].set_visible(False)
  265. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  266. matplotlib.use('Agg')
  267. plt.close()
  268. # loggers
  269. for k, v in loggers.items() or {}:
  270. if k == 'wandb' and v:
  271. v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
  272. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  273. # Plot hyperparameter evolution results in evolve.txt
  274. with open(yaml_file) as f:
  275. hyp = yaml.safe_load(f)
  276. x = np.loadtxt('evolve.txt', ndmin=2)
  277. f = fitness(x)
  278. # weights = (f - f.min()) ** 2 # for weighted results
  279. plt.figure(figsize=(10, 12), tight_layout=True)
  280. matplotlib.rc('font', **{'size': 8})
  281. for i, (k, v) in enumerate(hyp.items()):
  282. y = x[:, i + 7]
  283. # mu = (y * weights).sum() / weights.sum() # best weighted result
  284. mu = y[f.argmax()] # best single result
  285. plt.subplot(6, 5, i + 1)
  286. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  287. plt.plot(mu, f.max(), 'k+', markersize=15)
  288. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  289. if i % 5 != 0:
  290. plt.yticks([])
  291. print('%15s: %.3g' % (k, mu))
  292. plt.savefig('evolve.png', dpi=200)
  293. print('\nPlot saved as evolve.png')
  294. def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
  295. # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
  296. ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
  297. s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
  298. files = list(Path(save_dir).glob('frames*.txt'))
  299. for fi, f in enumerate(files):
  300. try:
  301. results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
  302. n = results.shape[1] # number of rows
  303. x = np.arange(start, min(stop, n) if stop else n)
  304. results = results[:, x]
  305. t = (results[0] - results[0].min()) # set t0=0s
  306. results[0] = x
  307. for i, a in enumerate(ax):
  308. if i < len(results):
  309. label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
  310. a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
  311. a.set_title(s[i])
  312. a.set_xlabel('time (s)')
  313. # if fi == len(files) - 1:
  314. # a.set_ylim(bottom=0)
  315. for side in ['top', 'right']:
  316. a.spines[side].set_visible(False)
  317. else:
  318. a.remove()
  319. except Exception as e:
  320. print('Warning: Plotting error for %s; %s' % (f, e))
  321. ax[1].legend()
  322. plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  323. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  324. # Plot training 'results*.txt', overlaying train and val losses
  325. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  326. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  327. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  328. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  329. n = results.shape[1] # number of rows
  330. x = range(start, min(stop, n) if stop else n)
  331. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  332. ax = ax.ravel()
  333. for i in range(5):
  334. for j in [i, i + 5]:
  335. y = results[j, x]
  336. ax[i].plot(x, y, marker='.', label=s[j])
  337. # y_smooth = butter_lowpass_filtfilt(y)
  338. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  339. ax[i].set_title(t[i])
  340. ax[i].legend()
  341. ax[i].set_ylabel(f) if i == 0 else None # add filename
  342. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  343. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  344. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  345. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  346. ax = ax.ravel()
  347. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  348. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  349. if bucket:
  350. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  351. files = ['results%g.txt' % x for x in id]
  352. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  353. os.system(c)
  354. else:
  355. files = list(Path(save_dir).glob('results*.txt'))
  356. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  357. for fi, f in enumerate(files):
  358. try:
  359. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  360. n = results.shape[1] # number of rows
  361. x = range(start, min(stop, n) if stop else n)
  362. for i in range(10):
  363. y = results[i, x]
  364. if i in [0, 1, 2, 5, 6, 7]:
  365. y[y == 0] = np.nan # don't show zero loss values
  366. # y /= y[0] # normalize
  367. label = labels[fi] if len(labels) else f.stem
  368. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  369. ax[i].set_title(s[i])
  370. # if i in [5, 6, 7]: # share train and val loss y axes
  371. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  372. except Exception as e:
  373. print('Warning: Plotting error for %s; %s' % (f, e))
  374. ax[1].legend()
  375. fig.savefig(Path(save_dir) / 'results.png', dpi=200)