TensorRT转化代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

plots.py 18KB

8 kuukautta sitten
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # Plotting utils
  2. from copy import copy
  3. from pathlib import Path
  4. import cv2
  5. import math
  6. import matplotlib
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import pandas as pd
  10. import seaborn as sn
  11. import torch
  12. import yaml
  13. from PIL import Image, ImageDraw, ImageFont
  14. from utils.general import xywh2xyxy, xyxy2xywh
  15. from utils.metrics import fitness
  16. # Settings
  17. matplotlib.rc('font', **{'size': 11})
  18. matplotlib.use('Agg') # for writing to files only
  19. class Colors:
  20. # Ultralytics color palette https://ultralytics.com/
  21. def __init__(self):
  22. # hex = matplotlib.colors.TABLEAU_COLORS.values()
  23. hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
  24. '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
  25. self.palette = [self.hex2rgb('#' + c) for c in hex]
  26. self.n = len(self.palette)
  27. def __call__(self, i, bgr=False):
  28. c = self.palette[int(i) % self.n]
  29. return (c[2], c[1], c[0]) if bgr else c
  30. @staticmethod
  31. def hex2rgb(h): # rgb order (PIL)
  32. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  33. colors = Colors() # create instance for 'from utils.plots import colors'
  34. def hist2d(x, y, n=100):
  35. # 2d histogram used in labels.png and evolve.png
  36. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  37. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  38. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  39. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  40. return np.log(hist[xidx, yidx])
  41. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  42. from scipy.signal import butter, filtfilt
  43. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  44. def butter_lowpass(cutoff, fs, order):
  45. nyq = 0.5 * fs
  46. normal_cutoff = cutoff / nyq
  47. return butter(order, normal_cutoff, btype='low', analog=False)
  48. b, a = butter_lowpass(cutoff, fs, order=order)
  49. return filtfilt(b, a, data) # forward-backward filter
  50. def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
  51. # Plots one bounding box on image 'im' using OpenCV
  52. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
  53. tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
  54. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  55. cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  56. if label:
  57. tf = max(tl - 1, 1) # font thickness
  58. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  59. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  60. cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
  61. cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  62. def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
  63. # Plots one bounding box on image 'im' using PIL
  64. im = Image.fromarray(im)
  65. draw = ImageDraw.Draw(im)
  66. line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
  67. draw.rectangle(box, width=line_thickness, outline=color) # plot
  68. if label:
  69. font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
  70. txt_width, txt_height = font.getsize(label)
  71. draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
  72. draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
  73. return np.asarray(im)
  74. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  75. # Compares the two methods for width-height anchor multiplication
  76. # https://github.com/ultralytics/yolov3/issues/168
  77. x = np.arange(-4.0, 4.0, .1)
  78. ya = np.exp(x)
  79. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  80. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  81. plt.plot(x, ya, '.-', label='YOLOv3')
  82. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  83. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  84. plt.xlim(left=-4, right=4)
  85. plt.ylim(bottom=0, top=6)
  86. plt.xlabel('input')
  87. plt.ylabel('output')
  88. plt.grid()
  89. plt.legend()
  90. fig.savefig('comparison.png', dpi=200)
  91. def output_to_target(output):
  92. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  93. targets = []
  94. for i, o in enumerate(output):
  95. for *box, conf, cls in o.cpu().numpy():
  96. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  97. return np.array(targets)
  98. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  99. # Plot image grid with labels
  100. if isinstance(images, torch.Tensor):
  101. images = images.cpu().float().numpy()
  102. if isinstance(targets, torch.Tensor):
  103. targets = targets.cpu().numpy()
  104. # un-normalise
  105. if np.max(images[0]) <= 1:
  106. images *= 255
  107. tl = 3 # line thickness
  108. tf = max(tl - 1, 1) # font thickness
  109. bs, _, h, w = images.shape # batch size, _, height, width
  110. bs = min(bs, max_subplots) # limit plot images
  111. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  112. # Check if we should resize
  113. scale_factor = max_size / max(h, w)
  114. if scale_factor < 1:
  115. h = math.ceil(scale_factor * h)
  116. w = math.ceil(scale_factor * w)
  117. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  118. for i, img in enumerate(images):
  119. if i == max_subplots: # if last batch has fewer images than we expect
  120. break
  121. block_x = int(w * (i // ns))
  122. block_y = int(h * (i % ns))
  123. img = img.transpose(1, 2, 0)
  124. if scale_factor < 1:
  125. img = cv2.resize(img, (w, h))
  126. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  127. if len(targets) > 0:
  128. image_targets = targets[targets[:, 0] == i]
  129. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  130. classes = image_targets[:, 1].astype('int')
  131. labels = image_targets.shape[1] == 6 # labels if no conf column
  132. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  133. if boxes.shape[1]:
  134. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  135. boxes[[0, 2]] *= w # scale to pixels
  136. boxes[[1, 3]] *= h
  137. elif scale_factor < 1: # absolute coords need scale if image scales
  138. boxes *= scale_factor
  139. boxes[[0, 2]] += block_x
  140. boxes[[1, 3]] += block_y
  141. for j, box in enumerate(boxes.T):
  142. cls = int(classes[j])
  143. color = colors(cls)
  144. cls = names[cls] if names else cls
  145. if labels or conf[j] > 0.25: # 0.25 conf thresh
  146. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  147. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  148. # Draw image filename labels
  149. if paths:
  150. label = Path(paths[i]).name[:40] # trim to 40 char
  151. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  152. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  153. lineType=cv2.LINE_AA)
  154. # Image border
  155. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  156. if fname:
  157. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  158. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  159. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  160. Image.fromarray(mosaic).save(fname) # PIL save
  161. return mosaic
  162. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  163. # Plot LR simulating training for full epochs
  164. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  165. y = []
  166. for _ in range(epochs):
  167. scheduler.step()
  168. y.append(optimizer.param_groups[0]['lr'])
  169. plt.plot(y, '.-', label='LR')
  170. plt.xlabel('epoch')
  171. plt.ylabel('LR')
  172. plt.grid()
  173. plt.xlim(0, epochs)
  174. plt.ylim(0)
  175. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  176. plt.close()
  177. def plot_val_txt(): # from utils.plots import *; plot_val()
  178. # Plot val.txt histograms
  179. x = np.loadtxt('val.txt', dtype=np.float32)
  180. box = xyxy2xywh(x[:, :4])
  181. cx, cy = box[:, 0], box[:, 1]
  182. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  183. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  184. ax.set_aspect('equal')
  185. plt.savefig('hist2d.png', dpi=300)
  186. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  187. ax[0].hist(cx, bins=600)
  188. ax[1].hist(cy, bins=600)
  189. plt.savefig('hist1d.png', dpi=200)
  190. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  191. # Plot targets.txt histograms
  192. x = np.loadtxt('targets.txt', dtype=np.float32).T
  193. s = ['x targets', 'y targets', 'width targets', 'height targets']
  194. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  195. ax = ax.ravel()
  196. for i in range(4):
  197. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  198. ax[i].legend()
  199. ax[i].set_title(s[i])
  200. plt.savefig('targets.jpg', dpi=200)
  201. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  202. # Plot study.txt generated by val.py
  203. plot2 = False # plot additional results
  204. if plot2:
  205. ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
  206. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  207. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  208. for f in sorted(Path(path).glob('study*.txt')):
  209. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  210. x = np.arange(y.shape[1]) if x is None else np.array(x)
  211. if plot2:
  212. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
  213. for i in range(7):
  214. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  215. ax[i].set_title(s[i])
  216. j = y[3].argmax() + 1
  217. ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  218. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  219. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  220. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  221. ax2.grid(alpha=0.2)
  222. ax2.set_yticks(np.arange(20, 60, 5))
  223. ax2.set_xlim(0, 57)
  224. ax2.set_ylim(30, 55)
  225. ax2.set_xlabel('GPU Speed (ms/img)')
  226. ax2.set_ylabel('COCO AP val')
  227. ax2.legend(loc='lower right')
  228. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  229. def plot_labels(labels, names=(), save_dir=Path('')):
  230. # plot dataset labels
  231. print('Plotting labels... ')
  232. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  233. nc = int(c.max() + 1) # number of classes
  234. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  235. # seaborn correlogram
  236. sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  237. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  238. plt.close()
  239. # matplotlib labels
  240. matplotlib.use('svg') # faster
  241. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  242. y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  243. # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
  244. ax[0].set_ylabel('instances')
  245. if 0 < len(names) < 30:
  246. ax[0].set_xticks(range(len(names)))
  247. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  248. else:
  249. ax[0].set_xlabel('classes')
  250. sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  251. sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  252. # rectangles
  253. labels[:, 1:3] = 0.5 # center
  254. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  255. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  256. for cls, *box in labels[:1000]:
  257. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  258. ax[1].imshow(img)
  259. ax[1].axis('off')
  260. for a in [0, 1, 2, 3]:
  261. for s in ['top', 'right', 'left', 'bottom']:
  262. ax[a].spines[s].set_visible(False)
  263. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  264. matplotlib.use('Agg')
  265. plt.close()
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
    """Plot iDetection '*.txt' per-image logs. Usage: from utils.plots import *; profile_idetection()

    Every save_dir/frames*.txt file is loaded with np.loadtxt and transposed. Its first 90 and
    last 30 rows are dropped, and rows [start, stop) of the rest are kept (stop=0 means all).
    The series are drawn on a 2x4 grid against t (row 0 shifted to start at 0); row 0 itself
    is replaced by the row index. Unused axes are removed. A per-file failure prints a warning
    instead of raising. The figure is saved to save_dir/idetection_profile.png.

    Args:
        start: first row kept after clipping.
        stop: row after the last one kept; 0 keeps everything.
        labels: optional legend labels, one per file; defaults to the file stem without 'frames_'.
        save_dir: directory searched for frames*.txt and where the PNG is written.
    """
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
    # NOTE(review): 7 titles for 8 axes; a file with 8+ columns raises IndexError on s[7],
    # which the except below turns into a warning.
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
    files = list(Path(save_dir).glob('frames*.txt'))
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
            n = results.shape[1]  # number of rows
            x = np.arange(start, min(stop, n) if stop else n)
            results = results[:, x]
            # assumes column 0 of the log is a timestamp in seconds (x label says 'time (s)') — confirm
            t = (results[0] - results[0].min())  # set t0=0s
            results[0] = x  # the 'Images' series becomes the row index
            for i, a in enumerate(ax):
                if i < len(results):
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
                    a.set_title(s[i])
                    a.set_xlabel('time (s)')
                    # if fi == len(files) - 1:
                    #     a.set_ylim(bottom=0)
                    for side in ['top', 'right']:
                        a.spines[side].set_visible(False)
                else:
                    # NOTE(review): with several files this removes the same axes again on
                    # later files — confirm matplotlib tolerates a repeated remove()
                    a.remove()
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (f, e))

    ax[1].legend()
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  295. def plot_evolve(evolve_csv=Path('path/to/evolve.csv')): # from utils.plots import *; plot_evolve()
  296. # Plot evolve.csv hyp evolution results
  297. data = pd.read_csv(evolve_csv)
  298. keys = [x.strip() for x in data.columns]
  299. x = data.values
  300. f = fitness(x)
  301. j = np.argmax(f) # max fitness index
  302. plt.figure(figsize=(10, 12), tight_layout=True)
  303. matplotlib.rc('font', **{'size': 8})
  304. for i, k in enumerate(keys[7:]):
  305. v = x[:, 7 + i]
  306. mu = v[j] # best single result
  307. plt.subplot(6, 5, i + 1)
  308. plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  309. plt.plot(mu, f.max(), 'k+', markersize=15)
  310. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  311. if i % 5 != 0:
  312. plt.yticks([])
  313. print('%15s: %.3g' % (k, mu))
  314. f = evolve_csv.with_suffix('.png') # filename
  315. plt.savefig(f, dpi=200)
  316. print(f'Saved {f}')
  317. def plot_results(file='path/to/results.csv', dir=''):
  318. # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
  319. save_dir = Path(file).parent if file else Path(dir)
  320. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  321. ax = ax.ravel()
  322. files = list(save_dir.glob('results*.csv'))
  323. assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
  324. for fi, f in enumerate(files):
  325. try:
  326. data = pd.read_csv(f)
  327. s = [x.strip() for x in data.columns]
  328. x = data.values[:, 0]
  329. for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
  330. y = data.values[:, j]
  331. # y[y == 0] = np.nan # don't show zero values
  332. ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
  333. ax[i].set_title(s[j], fontsize=12)
  334. # if j in [8, 9, 10]: # share train and val loss y axes
  335. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  336. except Exception as e:
  337. print(f'Warning: Plotting error for {f}: {e}')
  338. ax[1].legend()
  339. fig.savefig(save_dir / 'results.png', dpi=200)
  340. def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
  341. """
  342. x: Features to be visualized
  343. module_type: Module type
  344. stage: Module stage within model
  345. n: Maximum number of feature maps to plot
  346. save_dir: Directory to save results
  347. """
  348. if 'Detect' not in module_type:
  349. batch, channels, height, width = x.shape # batch, channels, height, width
  350. if height > 1 and width > 1:
  351. f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
  352. blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
  353. n = min(n, channels) # number of plots
  354. fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
  355. ax = ax.ravel()
  356. plt.subplots_adjust(wspace=0.05, hspace=0.05)
  357. for i in range(n):
  358. ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
  359. ax[i].axis('off')
  360. print(f'Saving {save_dir / f}... ({n}/{channels})')
  361. plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')