You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

426 satır
18KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Plotting utils
  4. """
  5. import math
  6. from copy import copy
  7. from pathlib import Path
  8. import cv2
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import pandas as pd
  13. import seaborn as sn
  14. import torch
  15. from PIL import Image, ImageDraw, ImageFont
  16. from utils.general import is_ascii, xyxy2xywh, xywh2xyxy
  17. from utils.metrics import fitness
  18. # Settings
  19. matplotlib.rc('font', **{'size': 11})
  20. matplotlib.use('Agg') # for writing to files only
  21. class Colors:
  22. # Ultralytics color palette https://ultralytics.com/
  23. def __init__(self):
  24. # hex = matplotlib.colors.TABLEAU_COLORS.values()
  25. hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
  26. '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
  27. self.palette = [self.hex2rgb('#' + c) for c in hex]
  28. self.n = len(self.palette)
  29. def __call__(self, i, bgr=False):
  30. c = self.palette[int(i) % self.n]
  31. return (c[2], c[1], c[0]) if bgr else c
  32. @staticmethod
  33. def hex2rgb(h): # rgb order (PIL)
  34. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  35. colors = Colors() # create instance for 'from utils.plots import colors'
  36. class Annotator:
  37. # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
  38. def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=True):
  39. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
  40. self.pil = pil
  41. if self.pil: # use PIL
  42. self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
  43. self.draw = ImageDraw.Draw(self.im)
  44. s = sum(self.im.size) / 2 # mean shape
  45. f = font_size or max(round(s * 0.035), 12)
  46. try:
  47. self.font = ImageFont.truetype(font, size=f)
  48. except Exception as e: # download if missing
  49. url = "https://ultralytics.com/assets/" + font
  50. print(f'Downloading {url} to {font}...')
  51. torch.hub.download_url_to_file(url, font)
  52. self.font = ImageFont.truetype(font, size=f)
  53. self.fh = self.font.getsize('a')[1] - 3 # font height
  54. else: # use cv2
  55. self.im = im
  56. s = sum(im.shape) / 2 # mean shape
  57. self.lw = line_width or max(round(s * 0.003), 2) # line width
  58. def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
  59. # Add one xyxy box to image with label
  60. if self.pil or not is_ascii(label):
  61. self.draw.rectangle(box, width=self.lw, outline=color) # box
  62. if label:
  63. w = self.font.getsize(label)[0] # text width
  64. self.draw.rectangle([box[0], box[1] - self.fh, box[0] + w + 1, box[1] + 1], fill=color)
  65. self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')
  66. else: # cv2
  67. c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
  68. cv2.rectangle(self.im, c1, c2, color, thickness=self.lw, lineType=cv2.LINE_AA)
  69. if label:
  70. tf = max(self.lw - 1, 1) # font thickness
  71. w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
  72. c2 = c1[0] + w, c1[1] - h - 3
  73. cv2.rectangle(self.im, c1, c2, color, -1, cv2.LINE_AA) # filled
  74. cv2.putText(self.im, label, (c1[0], c1[1] - 2), 0, self.lw / 3, txt_color, thickness=tf,
  75. lineType=cv2.LINE_AA)
  76. def rectangle(self, xy, fill=None, outline=None, width=1):
  77. # Add rectangle to image (PIL-only)
  78. self.draw.rectangle(xy, fill, outline, width)
  79. def text(self, xy, text, txt_color=(255, 255, 255)):
  80. # Add text to image (PIL-only)
  81. w, h = self.font.getsize(text) # text width, height
  82. self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)
  83. def result(self):
  84. # Return annotated image as array
  85. return np.asarray(self.im)
  86. def hist2d(x, y, n=100):
  87. # 2d histogram used in labels.png and evolve.png
  88. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  89. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  90. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  91. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  92. return np.log(hist[xidx, yidx])
  93. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  94. from scipy.signal import butter, filtfilt
  95. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  96. def butter_lowpass(cutoff, fs, order):
  97. nyq = 0.5 * fs
  98. normal_cutoff = cutoff / nyq
  99. return butter(order, normal_cutoff, btype='low', analog=False)
  100. b, a = butter_lowpass(cutoff, fs, order=order)
  101. return filtfilt(b, a, data) # forward-backward filter
  102. def output_to_target(output):
  103. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  104. targets = []
  105. for i, o in enumerate(output):
  106. for *box, conf, cls in o.cpu().numpy():
  107. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  108. return np.array(targets)
  109. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
  110. # Plot image grid with labels
  111. if isinstance(images, torch.Tensor):
  112. images = images.cpu().float().numpy()
  113. if isinstance(targets, torch.Tensor):
  114. targets = targets.cpu().numpy()
  115. if np.max(images[0]) <= 1:
  116. images *= 255.0 # de-normalise (optional)
  117. bs, _, h, w = images.shape # batch size, _, height, width
  118. bs = min(bs, max_subplots) # limit plot images
  119. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  120. # Build Image
  121. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  122. for i, im in enumerate(images):
  123. if i == max_subplots: # if last batch has fewer images than we expect
  124. break
  125. x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
  126. im = im.transpose(1, 2, 0)
  127. mosaic[y:y + h, x:x + w, :] = im
  128. # Resize (optional)
  129. scale = max_size / ns / max(h, w)
  130. if scale < 1:
  131. h = math.ceil(scale * h)
  132. w = math.ceil(scale * w)
  133. mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
  134. # Annotate
  135. fs = int((h + w) * ns * 0.01) # font size
  136. annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs)
  137. for i in range(i + 1):
  138. x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
  139. annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
  140. if paths:
  141. annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
  142. if len(targets) > 0:
  143. ti = targets[targets[:, 0] == i] # image targets
  144. boxes = xywh2xyxy(ti[:, 2:6]).T
  145. classes = ti[:, 1].astype('int')
  146. labels = ti.shape[1] == 6 # labels if no conf column
  147. conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
  148. if boxes.shape[1]:
  149. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  150. boxes[[0, 2]] *= w # scale to pixels
  151. boxes[[1, 3]] *= h
  152. elif scale < 1: # absolute coords need scale if image scales
  153. boxes *= scale
  154. boxes[[0, 2]] += x
  155. boxes[[1, 3]] += y
  156. for j, box in enumerate(boxes.T.tolist()):
  157. cls = classes[j]
  158. color = colors(cls)
  159. cls = names[cls] if names else cls
  160. if labels or conf[j] > 0.25: # 0.25 conf thresh
  161. label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
  162. annotator.box_label(box, label, color=color)
  163. annotator.im.save(fname) # save
  164. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  165. # Plot LR simulating training for full epochs
  166. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  167. y = []
  168. for _ in range(epochs):
  169. scheduler.step()
  170. y.append(optimizer.param_groups[0]['lr'])
  171. plt.plot(y, '.-', label='LR')
  172. plt.xlabel('epoch')
  173. plt.ylabel('LR')
  174. plt.grid()
  175. plt.xlim(0, epochs)
  176. plt.ylim(0)
  177. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  178. plt.close()
  179. def plot_val_txt(): # from utils.plots import *; plot_val()
  180. # Plot val.txt histograms
  181. x = np.loadtxt('val.txt', dtype=np.float32)
  182. box = xyxy2xywh(x[:, :4])
  183. cx, cy = box[:, 0], box[:, 1]
  184. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  185. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  186. ax.set_aspect('equal')
  187. plt.savefig('hist2d.png', dpi=300)
  188. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  189. ax[0].hist(cx, bins=600)
  190. ax[1].hist(cy, bins=600)
  191. plt.savefig('hist1d.png', dpi=200)
  192. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  193. # Plot targets.txt histograms
  194. x = np.loadtxt('targets.txt', dtype=np.float32).T
  195. s = ['x targets', 'y targets', 'width targets', 'height targets']
  196. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  197. ax = ax.ravel()
  198. for i in range(4):
  199. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  200. ax[i].legend()
  201. ax[i].set_title(s[i])
  202. plt.savefig('targets.jpg', dpi=200)
  203. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  204. # Plot study.txt generated by val.py
  205. plot2 = False # plot additional results
  206. if plot2:
  207. ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
  208. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  209. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  210. for f in sorted(Path(path).glob('study*.txt')):
  211. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  212. x = np.arange(y.shape[1]) if x is None else np.array(x)
  213. if plot2:
  214. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
  215. for i in range(7):
  216. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  217. ax[i].set_title(s[i])
  218. j = y[3].argmax() + 1
  219. ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  220. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  221. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  222. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  223. ax2.grid(alpha=0.2)
  224. ax2.set_yticks(np.arange(20, 60, 5))
  225. ax2.set_xlim(0, 57)
  226. ax2.set_ylim(30, 55)
  227. ax2.set_xlabel('GPU Speed (ms/img)')
  228. ax2.set_ylabel('COCO AP val')
  229. ax2.legend(loc='lower right')
  230. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  231. def plot_labels(labels, names=(), save_dir=Path('')):
  232. # plot dataset labels
  233. print('Plotting labels... ')
  234. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  235. nc = int(c.max() + 1) # number of classes
  236. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  237. # seaborn correlogram
  238. sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  239. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  240. plt.close()
  241. # matplotlib labels
  242. matplotlib.use('svg') # faster
  243. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  244. y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  245. # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
  246. ax[0].set_ylabel('instances')
  247. if 0 < len(names) < 30:
  248. ax[0].set_xticks(range(len(names)))
  249. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  250. else:
  251. ax[0].set_xlabel('classes')
  252. sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  253. sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  254. # rectangles
  255. labels[:, 1:3] = 0.5 # center
  256. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  257. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  258. for cls, *box in labels[:1000]:
  259. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  260. ax[1].imshow(img)
  261. ax[1].axis('off')
  262. for a in [0, 1, 2, 3]:
  263. for s in ['top', 'right', 'left', 'bottom']:
  264. ax[a].spines[s].set_visible(False)
  265. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  266. matplotlib.use('Agg')
  267. plt.close()
  268. def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
  269. # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
  270. ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
  271. s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
  272. files = list(Path(save_dir).glob('frames*.txt'))
  273. for fi, f in enumerate(files):
  274. try:
  275. results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
  276. n = results.shape[1] # number of rows
  277. x = np.arange(start, min(stop, n) if stop else n)
  278. results = results[:, x]
  279. t = (results[0] - results[0].min()) # set t0=0s
  280. results[0] = x
  281. for i, a in enumerate(ax):
  282. if i < len(results):
  283. label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
  284. a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
  285. a.set_title(s[i])
  286. a.set_xlabel('time (s)')
  287. # if fi == len(files) - 1:
  288. # a.set_ylim(bottom=0)
  289. for side in ['top', 'right']:
  290. a.spines[side].set_visible(False)
  291. else:
  292. a.remove()
  293. except Exception as e:
  294. print('Warning: Plotting error for %s; %s' % (f, e))
  295. ax[1].legend()
  296. plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  297. def plot_evolve(evolve_csv=Path('path/to/evolve.csv')): # from utils.plots import *; plot_evolve()
  298. # Plot evolve.csv hyp evolution results
  299. data = pd.read_csv(evolve_csv)
  300. keys = [x.strip() for x in data.columns]
  301. x = data.values
  302. f = fitness(x)
  303. j = np.argmax(f) # max fitness index
  304. plt.figure(figsize=(10, 12), tight_layout=True)
  305. matplotlib.rc('font', **{'size': 8})
  306. for i, k in enumerate(keys[7:]):
  307. v = x[:, 7 + i]
  308. mu = v[j] # best single result
  309. plt.subplot(6, 5, i + 1)
  310. plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  311. plt.plot(mu, f.max(), 'k+', markersize=15)
  312. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  313. if i % 5 != 0:
  314. plt.yticks([])
  315. print('%15s: %.3g' % (k, mu))
  316. f = evolve_csv.with_suffix('.png') # filename
  317. plt.savefig(f, dpi=200)
  318. print(f'Saved {f}')
  319. def plot_results(file='path/to/results.csv', dir=''):
  320. # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
  321. save_dir = Path(file).parent if file else Path(dir)
  322. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  323. ax = ax.ravel()
  324. files = list(save_dir.glob('results*.csv'))
  325. assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
  326. for fi, f in enumerate(files):
  327. try:
  328. data = pd.read_csv(f)
  329. s = [x.strip() for x in data.columns]
  330. x = data.values[:, 0]
  331. for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
  332. y = data.values[:, j]
  333. # y[y == 0] = np.nan # don't show zero values
  334. ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
  335. ax[i].set_title(s[j], fontsize=12)
  336. # if j in [8, 9, 10]: # share train and val loss y axes
  337. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  338. except Exception as e:
  339. print(f'Warning: Plotting error for {f}: {e}')
  340. ax[1].legend()
  341. fig.savefig(save_dir / 'results.png', dpi=200)
  342. def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
  343. """
  344. x: Features to be visualized
  345. module_type: Module type
  346. stage: Module stage within model
  347. n: Maximum number of feature maps to plot
  348. save_dir: Directory to save results
  349. """
  350. if 'Detect' not in module_type:
  351. batch, channels, height, width = x.shape # batch, channels, height, width
  352. if height > 1 and width > 1:
  353. f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
  354. blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
  355. n = min(n, channels) # number of plots
  356. fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
  357. ax = ax.ravel()
  358. plt.subplots_adjust(wspace=0.05, hspace=0.05)
  359. for i in range(n):
  360. ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
  361. ax[i].axis('off')
  362. print(f'Saving {save_dir / f}... ({n}/{channels})')
  363. plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')