You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

642 lines
27KB

  1. # Plotting utils
  2. import glob
  3. import math
  4. import os,sys
  5. import random
  6. from copy import copy
  7. from pathlib import Path
  8. import cv2
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import pandas as pd
  13. import seaborn as sns
  14. import torch
  15. import yaml
  16. from PIL import Image, ImageDraw, ImageFont
  17. from scipy.signal import butter, filtfilt,savgol_filter
  18. from utils.general import xywh2xyxy, xyxy2xywh
  19. from utils.metrics import fitness
  20. # Settings
  21. matplotlib.rc('font', **{'size': 11})
  22. #matplotlib.use('Agg') # for writing to files only
  23. def smooth_outline(contours,p1,p2):
  24. arcontours=np.array(contours)
  25. coors_x=arcontours[0,:,0,0]
  26. coors_y=arcontours[0,:,0,1]
  27. coors_x_smooth= savgol_filter(coors_x,p1,p2)
  28. coors_y_smooth= savgol_filter(coors_y,p1,p2)
  29. arcontours[0,:,0,0] = coors_x_smooth
  30. arcontours[0,:,0,1] = coors_y_smooth
  31. return arcontours
  32. def smooth_outline_auto(contours):
  33. cnt = len(contours[0])
  34. p1 = int(cnt/12)*2+1
  35. p2 =3
  36. if p1<p2: p2 = p1-1
  37. return smooth_outline(contours,p1,p2)
  38. def get_websource(txtfile):
  39. with open(txtfile,'r') as fp:
  40. lines = fp.readlines()
  41. webs=[];ports=[];streamNames=[]
  42. for line in lines:
  43. try:
  44. sps = line.strip().split(' ')
  45. webs.append(sps[0])
  46. #rtmp://liveplay.yunhengzhizao.cn/live/demo_HD5M
  47. if 'rtmp' in sps[0]:
  48. name = sps[0].split('/')[4].split('_')[0]
  49. else:
  50. name = sps[0][-3:]
  51. ports.append(sps[1])
  52. streamNames.append(name)
  53. except:
  54. print('####format error : %s , in file:%s#####'%(line,txtfile))
  55. assert len(webs)>0
  56. return webs,ports,streamNames
  57. def get_label_array( color=None, label=None,outfontsize=None,fontpath="conf/platech.ttf"):
  58. # Plots one bounding box on image 'im' using PIL
  59. fontsize = outfontsize
  60. font = ImageFont.truetype(fontpath, fontsize,encoding='utf-8')
  61. txt_width, txt_height = font.getsize(label)
  62. im = np.zeros((txt_height,txt_width,3),dtype=np.uint8)
  63. im = Image.fromarray(im)
  64. draw = ImageDraw.Draw(im)
  65. draw.rectangle([0, 0 , txt_width, txt_height ], fill=tuple(color))
  66. draw.text(( 0 , -3 ), label, fill=(255, 255, 255), font=font)
  67. im_array = np.asarray(im)
  68. if outfontsize:
  69. scaley = outfontsize / txt_height
  70. im_array= cv2.resize(im_array,(0,0),fx = scaley ,fy =scaley)
  71. return im_array
  72. def get_label_arrays(labelnames,colors,outfontsize=40,fontpath="conf/platech.ttf"):
  73. label_arraylist = []
  74. if len(labelnames) > len(colors):
  75. print('#####labelnames cnt > colors cnt#####')
  76. for ii,labelname in enumerate(labelnames):
  77. color = colors[ii%20]
  78. label_arraylist.append(get_label_array(color=color,label=labelname,outfontsize=outfontsize,fontpath=fontpath))
  79. return label_arraylist
  80. def color_list():
  81. # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  82. def hex2rgb(h):
  83. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  84. return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949)
  85. def hist2d(x, y, n=100):
  86. # 2d histogram used in labels.png and evolve.png
  87. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  88. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  89. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  90. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  91. return np.log(hist[xidx, yidx])
  92. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  93. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  94. def butter_lowpass(cutoff, fs, order):
  95. nyq = 0.5 * fs
  96. normal_cutoff = cutoff / nyq
  97. return butter(order, normal_cutoff, btype='low', analog=False)
  98. b, a = butter_lowpass(cutoff, fs, order=order)
  99. return filtfilt(b, a, data) # forward-backward filter
# NOTE(review): Dead code. This module-level triple-quoted string is never
# assigned or executed. It preserves an earlier draft that draws `infos`
# annotations (attribute text, scores, boxes) with PIL and OpenCV. The draft
# refers to `image` and `infos`, which it does not define. It is a candidate
# for deletion.
'''image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype('./font/platech.ttf', 40, encoding='utf-8')
for info in infos:
detect = info['bndbox']
text = ','.join(list(info['attributes'].values()))
temp = -50
if info['name'] == 'vehicle':
temp = 20
draw.text((detect[0], detect[1] + temp), text, (0, 255, 255), font=font)
if 'scores' in info:
draw.text((detect[0], detect[3]), info['scores'], (0, 255, 0), font=font)
if 'pscore' in info:
draw.text((detect[2], detect[3]), str(round(info['pscore'],3)), (0, 255, 0), font=font)
image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
for info in infos:
detect = info['bndbox']
cv2.rectangle(image, (detect[0], detect[1]), (detect[2], detect[3]), (0, 255, 0), 1, cv2.LINE_AA)
return image'''
# NOTE(review): Dead code. This module-level string holds an earlier draft of
# plot_one_box_PIL that mixed OpenCV and PIL drawing. It is never executed,
# and the function actually in use is the separate `def plot_one_box_PIL` in
# this module. The draft would fail as written: it passes `t_size` (a (w, h)
# tuple) as the font size to ImageFont.truetype. It is a candidate for deletion.
'''def plot_one_box_PIL(x, im, color=None, label=None, line_thickness=3):
# Plots one bounding box on image 'im' using OpenCV
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
if label:
tf = max(tl - 1, 1) # font thickness
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
im = Image.fromarray(im)
draw = ImageDraw.Draw(im)
font = ImageFont.truetype('./font/platech.ttf', t_size, encoding='utf-8')
draw.text((c1[0], c1[1] - 2), label, (0, 255, 0), font=font)
#cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
return np.array(im) '''
  138. def plot_one_box(x, im, color=None, label=None, line_thickness=3):
  139. # Plots one bounding box on image 'im' using OpenCV
  140. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
  141. tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
  142. color = color or [random.randint(0, 255) for _ in range(3)]
  143. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  144. cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  145. if label:
  146. tf = max(tl - 1, 1) # font thickness
  147. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  148. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  149. cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
  150. cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  151. def plot_one_box_PIL(box, im, color=None, label=None, line_thickness=None):
  152. # Plots one bounding box on image 'im' using PIL
  153. im = Image.fromarray(im)
  154. draw = ImageDraw.Draw(im)
  155. line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
  156. draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot
  157. if label:
  158. fontsize = max(round(max(im.size) / 40), 12)
  159. font = ImageFont.truetype("../AIlib2/conf/platech.ttf", fontsize,encoding='utf-8')
  160. txt_width, txt_height = font.getsize(label)
  161. draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
  162. draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
  163. im_array = np.asarray(im)
  164. return np.asarray(im)
  165. def draw_painting_joint(box,img,label_array,score=0.5,color=None,font={ 'line_thickness':None,'boxLine_thickness':None, 'fontSize':None},socre_location="leftTop"):
  166. #如果box[0]不是list or 元组,则box是[ (x0,y0),(x1,y1),(x2,y2),(x3,y3)]四点格式
  167. if isinstance(box[0], (list, tuple,np.ndarray ) ):
  168. ###先把中文类别字体赋值到img中
  169. lh, lw, lc = label_array.shape
  170. imh, imw, imc = img.shape
  171. if socre_location=='leftTop':
  172. x0 , y1 = box[0][0],box[0][1]
  173. elif socre_location=='leftBottom':
  174. x0,y1=box[3][0],box[3][1]
  175. else:
  176. print('plot.py line217 ,label_location:%s not implemented '%( socre_location ))
  177. sys.exit(0)
  178. x1 , y0 = x0 + lw , y1 - lh
  179. if y0<0:y0=0;y1=y0+lh
  180. if y1>imh: y1=imh;y0=y1-lh
  181. if x0<0:x0=0;x1=x0+lw
  182. if x1>imw:x1=imw;x0=x1-lw
  183. img[y0:y1,x0:x1,:] = label_array
  184. pts_cls=[(x0,y0),(x1,y1) ]
  185. #把四边形的框画上
  186. box_tl= font['boxLine_thickness'] or round(0.002 * (imh + imw) / 2) + 1
  187. cv2.polylines(img, [box], True,color , box_tl)
  188. ####把英文字符score画到类别旁边
  189. tl = font['line_thickness'] or round(0.002*(imh+imw)/2)+1#line/font thickness
  190. label = ' %.2f'%(score)
  191. tf = max(tl , 1) # font thickness
  192. fontScale = font['fontSize'] or tl * 0.33
  193. t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]
  194. #if socre_location=='leftTop':
  195. p1,p2= (pts_cls[1][0], pts_cls[0][1]),(pts_cls[1][0]+t_size[0],pts_cls[1][1])
  196. cv2.rectangle(img, p1 , p2, color, -1, cv2.LINE_AA)
  197. p3 = pts_cls[1][0],pts_cls[1][1]-(lh-t_size[1])//2
  198. cv2.putText(img, label,p3, 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  199. return img
  200. else:####两点格式[x0,y0,x1,y1]
  201. try:
  202. box = [int(xx.cpu()) for xx in box]
  203. except:
  204. box=[ int(x) for x in box]
  205. ###先把中文类别字体赋值到img中
  206. lh, lw, lc = label_array.shape
  207. imh, imw, imc = img.shape
  208. if socre_location=='leftTop':
  209. x0 , y1 = box[0:2]
  210. elif socre_location=='leftBottom':
  211. x0,y1=box[0],box[3]
  212. else:
  213. print('plot.py line217 ,socre_location:%s not implemented '%( socre_location ))
  214. sys.exit(0)
  215. x1 , y0 = x0 + lw , y1 - lh
  216. if y0<0:y0=0;y1=y0+lh
  217. if y1>imh: y1=imh;y0=y1-lh
  218. if x0<0:x0=0;x1=x0+lw
  219. if x1>imw:x1=imw;x0=x1-lw
  220. img[y0:y1,x0:x1,:] = label_array
  221. ###把矩形框画上,指定颜色和线宽
  222. tl = font['line_thickness'] or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  223. box_tl= font['boxLine_thickness'] or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
  224. c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
  225. cv2.rectangle(img, c1, c2, color, thickness=box_tl, lineType=cv2.LINE_AA)
  226. ###把英文字符score画到类别旁边
  227. label = ' %.2f'%(score)
  228. tf = max(tl , 1) # font thickness
  229. fontScale = font['fontSize'] or tl * 0.33
  230. t_size = cv2.getTextSize(label, 0, fontScale=fontScale , thickness=tf)[0]
  231. if socre_location=='leftTop':
  232. c2 = c1[0]+ lw + t_size[0], c1[1] - lh
  233. cv2.rectangle(img, (int(box[0])+lw,int(box[1])) , c2, color, -1, cv2.LINE_AA) # filled
  234. cv2.putText(img, label, (c1[0]+lw, c1[1] - (lh-t_size[1])//2 ), 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  235. elif socre_location=='leftBottom':
  236. c2 = box[0]+ lw + t_size[0], box[3] - lh
  237. cv2.rectangle(img, (int(box[0])+lw,int(box[3])) , c2, color, -1, cv2.LINE_AA) # filled
  238. cv2.putText(img, label, ( box[0] + lw, box[3] - (lh-t_size[1])//2 ), 0, fontScale, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  239. #print('#####line224 fontScale:',fontScale,' thickness:',tf,' line_thickness:',font['line_thickness'],' boxLine thickness:',box_tl)
  240. return img
  241. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  242. # Compares the two methods for width-height anchor multiplication
  243. # https://github.com/ultralytics/yolov3/issues/168
  244. x = np.arange(-4.0, 4.0, .1)
  245. ya = np.exp(x)
  246. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  247. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  248. plt.plot(x, ya, '.-', label='YOLOv3')
  249. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  250. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  251. plt.xlim(left=-4, right=4)
  252. plt.ylim(bottom=0, top=6)
  253. plt.xlabel('input')
  254. plt.ylabel('output')
  255. plt.grid()
  256. plt.legend()
  257. fig.savefig('comparison.png', dpi=200)
  258. def output_to_target(output):
  259. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  260. targets = []
  261. for i, o in enumerate(output):
  262. for *box, conf, cls in o.cpu().numpy():
  263. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  264. return np.array(targets)
  265. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  266. # Plot image grid with labels
  267. if isinstance(images, torch.Tensor):
  268. images = images.cpu().float().numpy()
  269. if isinstance(targets, torch.Tensor):
  270. targets = targets.cpu().numpy()
  271. # un-normalise
  272. if np.max(images[0]) <= 1:
  273. images *= 255
  274. tl = 3 # line thickness
  275. tf = max(tl - 1, 1) # font thickness
  276. bs, _, h, w = images.shape # batch size, _, height, width
  277. bs = min(bs, max_subplots) # limit plot images
  278. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  279. # Check if we should resize
  280. scale_factor = max_size / max(h, w)
  281. if scale_factor < 1:
  282. h = math.ceil(scale_factor * h)
  283. w = math.ceil(scale_factor * w)
  284. colors = color_list() # list of colors
  285. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  286. for i, img in enumerate(images):
  287. if i == max_subplots: # if last batch has fewer images than we expect
  288. break
  289. block_x = int(w * (i // ns))
  290. block_y = int(h * (i % ns))
  291. img = img.transpose(1, 2, 0)
  292. if scale_factor < 1:
  293. img = cv2.resize(img, (w, h))
  294. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  295. if len(targets) > 0:
  296. image_targets = targets[targets[:, 0] == i]
  297. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  298. classes = image_targets[:, 1].astype('int')
  299. labels = image_targets.shape[1] == 6 # labels if no conf column
  300. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  301. if boxes.shape[1]:
  302. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  303. boxes[[0, 2]] *= w # scale to pixels
  304. boxes[[1, 3]] *= h
  305. elif scale_factor < 1: # absolute coords need scale if image scales
  306. boxes *= scale_factor
  307. boxes[[0, 2]] += block_x
  308. boxes[[1, 3]] += block_y
  309. for j, box in enumerate(boxes.T):
  310. cls = int(classes[j])
  311. color = colors[cls % len(colors)]
  312. cls = names[cls] if names else cls
  313. if labels or conf[j] > 0.25: # 0.25 conf thresh
  314. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  315. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  316. # Draw image filename labels
  317. if paths:
  318. label = Path(paths[i]).name[:40] # trim to 40 char
  319. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  320. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  321. lineType=cv2.LINE_AA)
  322. # Image border
  323. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  324. if fname:
  325. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  326. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  327. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  328. Image.fromarray(mosaic).save(fname) # PIL save
  329. return mosaic
  330. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  331. # Plot LR simulating training for full epochs
  332. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  333. y = []
  334. for _ in range(epochs):
  335. scheduler.step()
  336. y.append(optimizer.param_groups[0]['lr'])
  337. plt.plot(y, '.-', label='LR')
  338. plt.xlabel('epoch')
  339. plt.ylabel('LR')
  340. plt.grid()
  341. plt.xlim(0, epochs)
  342. plt.ylim(0)
  343. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  344. plt.close()
  345. def plot_test_txt(): # from utils.plots import *; plot_test()
  346. # Plot test.txt histograms
  347. x = np.loadtxt('test.txt', dtype=np.float32)
  348. box = xyxy2xywh(x[:, :4])
  349. cx, cy = box[:, 0], box[:, 1]
  350. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  351. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  352. ax.set_aspect('equal')
  353. plt.savefig('hist2d.png', dpi=300)
  354. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  355. ax[0].hist(cx, bins=600)
  356. ax[1].hist(cy, bins=600)
  357. plt.savefig('hist1d.png', dpi=200)
  358. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  359. # Plot targets.txt histograms
  360. x = np.loadtxt('targets.txt', dtype=np.float32).T
  361. s = ['x targets', 'y targets', 'width targets', 'height targets']
  362. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  363. ax = ax.ravel()
  364. for i in range(4):
  365. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  366. ax[i].legend()
  367. ax[i].set_title(s[i])
  368. plt.savefig('targets.jpg', dpi=200)
  369. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  370. # Plot study.txt generated by test.py
  371. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  372. # ax = ax.ravel()
  373. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  374. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  375. for f in sorted(Path(path).glob('study*.txt')):
  376. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  377. x = np.arange(y.shape[1]) if x is None else np.array(x)
  378. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  379. # for i in range(7):
  380. # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  381. # ax[i].set_title(s[i])
  382. j = y[3].argmax() + 1
  383. ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  384. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  385. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  386. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  387. ax2.grid(alpha=0.2)
  388. ax2.set_yticks(np.arange(20, 60, 5))
  389. ax2.set_xlim(0, 57)
  390. ax2.set_ylim(30, 55)
  391. ax2.set_xlabel('GPU Speed (ms/img)')
  392. ax2.set_ylabel('COCO AP val')
  393. ax2.legend(loc='lower right')
  394. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  395. def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
  396. # plot dataset labels
  397. print('Plotting labels... ')
  398. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  399. nc = int(c.max() + 1) # number of classes
  400. colors = color_list()
  401. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  402. # seaborn correlogram
  403. sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  404. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  405. plt.close()
  406. # matplotlib labels
  407. matplotlib.use('svg') # faster
  408. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  409. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  410. ax[0].set_ylabel('instances')
  411. if 0 < len(names) < 30:
  412. ax[0].set_xticks(range(len(names)))
  413. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  414. else:
  415. ax[0].set_xlabel('classes')
  416. sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  417. sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  418. # rectangles
  419. labels[:, 1:3] = 0.5 # center
  420. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  421. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  422. for cls, *box in labels[:1000]:
  423. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
  424. ax[1].imshow(img)
  425. ax[1].axis('off')
  426. for a in [0, 1, 2, 3]:
  427. for s in ['top', 'right', 'left', 'bottom']:
  428. ax[a].spines[s].set_visible(False)
  429. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  430. matplotlib.use('Agg')
  431. plt.close()
  432. # loggers
  433. for k, v in loggers.items() or {}:
  434. if k == 'wandb' and v:
  435. v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
  436. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  437. # Plot hyperparameter evolution results in evolve.txt
  438. with open(yaml_file) as f:
  439. hyp = yaml.load(f, Loader=yaml.SafeLoader)
  440. x = np.loadtxt('evolve.txt', ndmin=2)
  441. f = fitness(x)
  442. # weights = (f - f.min()) ** 2 # for weighted results
  443. plt.figure(figsize=(10, 12), tight_layout=True)
  444. matplotlib.rc('font', **{'size': 8})
  445. for i, (k, v) in enumerate(hyp.items()):
  446. y = x[:, i + 7]
  447. # mu = (y * weights).sum() / weights.sum() # best weighted result
  448. mu = y[f.argmax()] # best single result
  449. plt.subplot(6, 5, i + 1)
  450. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  451. plt.plot(mu, f.max(), 'k+', markersize=15)
  452. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  453. if i % 5 != 0:
  454. plt.yticks([])
  455. print('%15s: %.3g' % (k, mu))
  456. plt.savefig('evolve.png', dpi=200)
  457. print('\nPlot saved as evolve.png')
  458. def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
  459. # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
  460. ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
  461. s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
  462. files = list(Path(save_dir).glob('frames*.txt'))
  463. for fi, f in enumerate(files):
  464. try:
  465. results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
  466. n = results.shape[1] # number of rows
  467. x = np.arange(start, min(stop, n) if stop else n)
  468. results = results[:, x]
  469. t = (results[0] - results[0].min()) # set t0=0s
  470. results[0] = x
  471. for i, a in enumerate(ax):
  472. if i < len(results):
  473. label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
  474. a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
  475. a.set_title(s[i])
  476. a.set_xlabel('time (s)')
  477. # if fi == len(files) - 1:
  478. # a.set_ylim(bottom=0)
  479. for side in ['top', 'right']:
  480. a.spines[side].set_visible(False)
  481. else:
  482. a.remove()
  483. except Exception as e:
  484. print('Warning: Plotting error for %s; %s' % (f, e))
  485. ax[1].legend()
  486. plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  487. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  488. # Plot training 'results*.txt', overlaying train and val losses
  489. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  490. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  491. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  492. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  493. n = results.shape[1] # number of rows
  494. x = range(start, min(stop, n) if stop else n)
  495. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  496. ax = ax.ravel()
  497. for i in range(5):
  498. for j in [i, i + 5]:
  499. y = results[j, x]
  500. ax[i].plot(x, y, marker='.', label=s[j])
  501. # y_smooth = butter_lowpass_filtfilt(y)
  502. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  503. ax[i].set_title(t[i])
  504. ax[i].legend()
  505. ax[i].set_ylabel(f) if i == 0 else None # add filename
  506. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  507. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  508. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  509. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  510. ax = ax.ravel()
  511. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  512. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  513. if bucket:
  514. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  515. files = ['results%g.txt' % x for x in id]
  516. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  517. os.system(c)
  518. else:
  519. files = list(Path(save_dir).glob('results*.txt'))
  520. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  521. for fi, f in enumerate(files):
  522. try:
  523. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  524. n = results.shape[1] # number of rows
  525. x = range(start, min(stop, n) if stop else n)
  526. for i in range(10):
  527. y = results[i, x]
  528. if i in [0, 1, 2, 5, 6, 7]:
  529. y[y == 0] = np.nan # don't show zero loss values
  530. # y /= y[0] # normalize
  531. label = labels[fi] if len(labels) else f.stem
  532. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  533. ax[i].set_title(s[i])
  534. # if i in [5, 6, 7]: # share train and val loss y axes
  535. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  536. except Exception as e:
  537. print('Warning: Plotting error for %s; %s' % (f, e))
  538. ax[1].legend()
  539. fig.savefig(Path(save_dir) / 'results.png', dpi=200)