Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

300 lines
11KB

  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Train and eval functions used in main.py
  4. Mostly copy-paste from DETR (https://github.com/facebookresearch/detr).
  5. """
  6. import math
  7. import os
  8. import sys
  9. from typing import Iterable
  10. import torch
  11. #print( os.path.abspath( os.path.dirname(__file__) ) )
  12. sys.path.append( os.path.abspath( os.path.dirname(__file__) ) )
  13. import util.misc as utils
  14. from util.misc import NestedTensor
  15. import numpy as np
  16. import time
  17. import torchvision.transforms as standard_transforms
  18. import cv2
  19. import PIL
  class DictToObject:
      """Wrap a dict so that its keys can be read as attributes.

      Nested dicts are wrapped recursively, e.g.
      ``DictToObject({'a': {'b': 1}}).a.b == 1``. Every other value
      (lists included) is stored unchanged.
      """

      def __init__(self, dictionary):
          for name, item in dictionary.items():
              wrapped = DictToObject(item) if isinstance(item, dict) else item
              setattr(self, name, wrapped)
  def letterImage(img, minShape, maxShape):
      """Rescale *img*, keeping its aspect ratio, so that H and W fit the bounds.

      If either side is below its minimum, the image is enlarged until the side
      that is furthest below lands exactly on its minimum. If either side is
      above its maximum, it is shrunk until the side that is furthest above
      lands exactly on its maximum; this second rule wins when both apply. The
      other side is floored to an integer.

      Args:
          img: array whose first two axes are (H, W), e.g. an (H, W, C) numpy image.
          minShape: 4-entry sequence whose entries 2 and 3 are (minH, minW); the
              leading two entries are ignored (presumably N, C — confirm).
          maxShape: same layout as *minShape*, giving (maxH, maxW).

      Returns:
          *img* itself (not a copy) when it already fits, otherwise the
          ``cv2.resize`` result with size (newH, newW).

      Raises:
          ValueError: if a resize is needed but the image is empty, or if no
              aspect-preserving size satisfies both bounds.
      """
      iH, iW = img.shape[0:2]
      minH, minW = minShape[2:]
      maxH, maxW = maxShape[2:]

      grow = iH < minH or iW < minW
      shrink = iH > maxH or iW > maxW
      if not (grow or shrink):
          return img
      if iH <= 0 or iW <= 0:
          raise ValueError('cannot letterbox an empty image (iH:%d iW:%d)' % (iH, iW))

      # Compare ratios by integer cross-multiplication and floor-divide instead
      # of dividing by a float scale factor: int(iH / (iH / minH)) can come out
      # as minH - 1 through round-off and then fail the range check below.
      if grow:
          # The side with the smaller size/min ratio is furthest below its
          # minimum: pin it to the minimum and floor the other side.
          if iH * minW <= iW * minH:
              newH, newW = minH, iW * minH // iH
          else:
              newH, newW = iH * minW // iW, minW
      if shrink:
          # The side with the larger size/max ratio is furthest above its
          # maximum: pin it to the maximum and floor the other side.
          if iH * maxW >= iW * maxH:
              newH, newW = maxH, iW * maxH // iH
          else:
              newH, newW = iH * maxW // iW, maxW
      newH, newW = int(newH), int(newW)

      # Raise rather than assert: asserts are stripped under ``python -O``.
      if not (minH <= newH <= maxH and minW <= newW <= maxW):
          raise ValueError(
              'cannot fit image iH:%d iW:%d into H[%s, %s] W[%s, %s] (got newH:%d newW:%d)'
              % (iH, iW, minH, maxH, minW, maxW, newH, newW))
      return cv2.resize(img, (newW, newH))
  def postprocess(outputs, threshold=0.5):
      """Keep the confident point predictions of the first image in a batch.

      Args:
          outputs: model output dict with ``'pred_logits'`` ([B, N, C] logits,
              C >= 2) and ``'pred_points'`` ([B, N, ...] points aligned with them).
          threshold: a proposal is kept when its class-1 softmax probability is
              strictly greater than this value.

      Returns:
          ``(points, scores)`` as plain Python lists, for batch element 0 only.
      """
      probs = torch.nn.functional.softmax(outputs['pred_logits'], -1)
      fg = probs[:, :, 1][0]
      mask = fg > threshold
      kept_points = outputs['pred_points'][0][mask]
      kept_scores = fg[mask]
      return (kept_points.detach().cpu().numpy().tolist(),
              kept_scores.detach().cpu().numpy().tolist())
  def toOBBformat(points, scores, cls=0):
      """Wrap point detections in the oriented-bounding-box result layout.

      Each point becomes a degenerate quadrilateral whose four corners are all
      that point, giving::

          [ [ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)], score, cls ],
            [ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)], score, cls ], ... ]

      Args:
          points: sequence of points (e.g. the ``[x, y]`` lists from ``postprocess``).
          scores: indexable scores; ``scores[i]`` is paired with ``points[i]``.
          cls: class id attached to every entry.

      Note: the four corners are the *same* object, not copies.
      """
      return [[[pt] * 4, scores[idx], cls] for idx, pt in enumerate(points)]
  def preprocess(img, mean, std, minShape, maxShape):
      """Turn an RGB image into a normalised float32 (C, H, W) network input.

      Steps:
      1. Convert PIL to numpy.
      2. Run ``letterImage`` to fit the image into the [minShape, maxShape] range.
      3. Resize each side down to a multiple of 128.
      4. Scale to [0, 1].
      5. Apply ``(x - mean[c]) / std[c]`` on channels 0..2.

      Args:
          img: RGB image, either a ``PIL.Image.Image`` or an (H, W, C) numpy array.
          mean: per-channel means, indexed 0..2.
          std: per-channel standard deviations, indexed 0..2.
          minShape, maxShape: bounds forwarded unchanged to ``letterImage``.

      Returns:
          float32 numpy array of shape (3, H', W') with H' and W' multiples of
          128; no batch axis is added.
      """
      if isinstance(img, PIL.Image.Image):
          img = np.array(img)
      img = letterImage(img, minShape, maxShape)
      rows, cols = img.shape[0:2]
      # Snap both sides down to multiples of 128 (presumably the network's
      # overall stride — confirm).
      target_w = cols // 128 * 128
      target_h = rows // 128 * 128
      scaled = cv2.resize(img, (target_w, target_h)) / 255.
      normalised = np.zeros((target_h, target_w, 3))
      for ch in range(3):
          normalised[:, :, ch] = (scaled[:, :, ch] - mean[ch]) / std[ch]
      # HWC -> CHW
      return normalised.transpose((2, 0, 1)).astype(np.float32)
  class DeNormalize(object):
      """Invert a per-channel normalisation in place: ``x = x * std + mean``.

      Args:
          mean: per-channel means.
          std: per-channel standard deviations.

      Channels are paired with ``zip``, so any channels beyond the shortest of
      (tensor, mean, std) are left untouched.
      """

      def __init__(self, mean, std):
          self.mean, self.std = mean, std

      def __call__(self, tensor):
          """Denormalise *tensor* (channels along its first axis) in place and return it."""
          for plane, mu, sigma in zip(tensor, self.mean, self.std):
              plane.mul_(sigma).add_(mu)
          return tensor
  # generate the reference points in grid layout
  def generate_anchor_points(stride=16, row=3, line=3):
      """Return reference-point offsets for one ``stride`` x ``stride`` cell.

      The cell is split into ``row`` x ``line`` sub-cells. The centre of each
      sub-cell is returned relative to the cell centre, as a float array of
      shape ``(row * line, 2)`` with columns (x, y); x varies fastest.
      """
      half = stride / 2
      xs = (np.arange(line) + 0.5) * (stride / line) - half
      ys = (np.arange(row) + 0.5) * (stride / row) - half
      grid_x, grid_y = np.meshgrid(xs, ys)
      return np.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
  def shift(shape, stride, anchor_points):
      """Tile ``anchor_points`` over every cell of a ``shape[0]`` x ``shape[1]`` grid.

      Cell (i, j) has centre ``((j + 0.5) * stride, (i + 0.5) * stride)``, and
      every anchor offset is added to every cell centre.

      Returns a (K * A, 2) array, where K = shape[0] * shape[1] and
      A = anchor_points.shape[0]. Rows are grouped by cell in row-major order,
      with the A anchors of a cell contiguous.
      """
      centres_x, centres_y = np.meshgrid((np.arange(0, shape[1]) + 0.5) * stride,
                                         (np.arange(0, shape[0]) + 0.5) * stride)
      centres = np.stack((centres_x.ravel(), centres_y.ravel()), axis=1)  # (K, 2)
      num_anchors = anchor_points.shape[0]
      num_cells = centres.shape[0]
      # (K, 1, 2) + (1, A, 2) broadcasts to (K, A, 2)
      tiled = centres[:, np.newaxis, :] + anchor_points.reshape((1, num_anchors, 2))
      return tiled.reshape((num_cells * num_anchors, 2))
  class AnchorPointsf(object):
      """Generate the reference (anchor) points for an input image batch.

      For each pyramid level p, a ``row`` x ``line`` grid of points spanning a
      ``2 ** p`` cell (``generate_anchor_points``) is tiled over that level's
      feature map (``shift``), and all levels are concatenated.

      Args:
          pyramid_levels: levels to use; ``None`` means [3, 4, 5, 6, 7].
          strides: per-level stride passed to ``shift``; ``None`` means
              ``2 ** level``. It needs at least one entry per pyramid level.
          row, line: rows / columns of reference points per cell.
          device: where ``eval`` puts its result. A non-CPU device is only
              used when CUDA is available; otherwise the result stays on CPU.
      """

      def __init__(self, pyramid_levels=[3,], strides=None, row=3, line=3, device='cpu'):
          if pyramid_levels is None:
              self.pyramid_levels = [3, 4, 5, 6, 7]
          else:
              # Copy: never alias the shared default list or the caller's list.
              self.pyramid_levels = list(pyramid_levels)
          if strides is None:
              self.strides = [2 ** x for x in self.pyramid_levels]
          else:
              # Explicit strides are used as given; eval() indexes them per level.
              self.strides = list(strides)
              if len(self.strides) < len(self.pyramid_levels):
                  raise ValueError('need one stride per pyramid level, got %d strides for %d levels'
                                   % (len(self.strides), len(self.pyramid_levels)))
          self.row = row
          self.line = line
          self.device = device

      def eval(self, image):
          """Return all reference points for *image* as a (1, N, 2) float32 tensor.

          Args:
              image: batch whose ``shape[2:]`` is (H, W) — presumably an NCHW
                  tensor; TODO confirm with callers.
          """
          image_shape = np.array(image.shape[2:])
          # Feature-map size of each level: ceil(size / 2 ** p).
          image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]
          # Seed with an empty (0, 2) array so that no levels still yields (1, 0, 2).
          per_level = [np.zeros((0, 2), dtype=np.float32)]
          for idx, p in enumerate(self.pyramid_levels):
              anchor_points = generate_anchor_points(2 ** p, row=self.row, line=self.line)
              per_level.append(shift(image_shapes[idx], self.strides[idx], anchor_points))
          all_anchor_points = np.expand_dims(np.concatenate(per_level, axis=0), axis=0)
          points = torch.from_numpy(all_anchor_points.astype(np.float32))
          # .to(device) rather than .cuda(): honours an explicit index such as 'cuda:1'.
          if torch.cuda.is_available() and torch.device(self.device).type != 'cpu':
              return points.to(self.device)
          return points
  def vis(samples, targets, pred, vis_dir, des=None):
      '''Save ground-truth and predicted points drawn on each image of a batch.

      For every image, two JPEGs are written to *vis_dir*: one with the
      ground-truth points in green and one with the predictions in red. The
      directory is not created here.

      Args:
          samples: tensor [batch, 3, H, W], normalised with
              mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
              It is not modified.
          targets: list of dicts, one per image. Each has:
              - ``'point'``: tensor of [num_gt, 2] (x, y) points.
              - ``'image_id'``: must be convertible with ``int()``; it
                prefixes the file names.
          pred: list with one entry per image, each a [num_preds, 2] sequence
              of (x, y) points.
          vis_dir: output directory.
          des: optional tag inserted into the file names after the image id.
      '''
      gts = [t['point'].tolist() for t in targets]
      to_tensor = standard_transforms.ToTensor()
      restore_transform = standard_transforms.Compose([
          DeNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
          standard_transforms.ToPILImage()
      ])
      size = 2
      # draw one by one
      for idx in range(samples.shape[0]):
          # DeNormalize works in place, so hand it a private CPU copy; otherwise
          # the caller's `samples` batch would be overwritten.
          sample = restore_transform(samples[idx].detach().cpu().clone())
          sample = to_tensor(sample.convert('RGB')).numpy() * 255
          # CHW RGB float -> HWC BGR uint8 for OpenCV, as a contiguous copy.
          sample_gt = sample.transpose([1, 2, 0])[:, :, ::-1].astype(np.uint8).copy()
          sample_pred = sample_gt.copy()
          # draw gt
          for t in gts[idx]:
              sample_gt = cv2.circle(sample_gt, (int(t[0]), int(t[1])), size, (0, 255, 0), -1)
          # draw predictions
          for p in pred[idx]:
              sample_pred = cv2.circle(sample_pred, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)
          # save the visualized images:
          # <id>[_<des>]_gt_<n_gt>_pred_<n_pred>_{gt,pred}.jpg
          name = targets[idx]['image_id']
          stem = '{}_{}'.format(int(name), des) if des is not None else '{}'.format(int(name))
          counts = 'gt_{}_pred_{}'.format(len(gts[idx]), len(pred[idx]))
          cv2.imwrite(os.path.join(vis_dir, '{}_{}_gt.jpg'.format(stem, counts)), sample_gt)
          cv2.imwrite(os.path.join(vis_dir, '{}_{}_pred.jpg'.format(stem, counts)), sample_pred)
  # the training routine
  def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                      data_loader: Iterable, optimizer: torch.optim.Optimizer,
                      device: torch.device, epoch: int, max_norm: float = 0):
      """Run one optimisation pass of *model* over *data_loader*.

      For each batch:
      1. Compute the weighted loss: ``criterion(...)`` values times
         ``criterion.weight_dict``, over keys present in both.
      2. Backpropagate that loss.
      3. If ``max_norm > 0``, clip the gradient norm.
      4. Take an optimizer step.
      5. Log the reduced scaled and unscaled losses.

      The process exits with status 1 if the reduced weighted loss is not
      finite. *epoch* is accepted but not used here.

      Returns:
          dict mapping each logged meter name to its global average.
      """
      model.train()
      criterion.train()
      logger = utils.MetricLogger(delimiter=" ")
      logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

      for images, annotations in data_loader:
          images = images.to(device)
          annotations = [{key: val.to(device) for key, val in ann.items()} for ann in annotations]

          predictions = model(images)
          raw_losses = criterion(predictions, annotations)
          weights = criterion.weight_dict
          total_loss = sum(raw_losses[name] * weights[name]
                           for name in raw_losses.keys() if name in weights)

          # Losses averaged across processes, for logging only.
          reduced = utils.reduce_dict(raw_losses)
          reduced_unscaled = {f'{name}_unscaled': val for name, val in reduced.items()}
          reduced_scaled = {name: val * weights[name]
                            for name, val in reduced.items() if name in weights}
          loss_value = sum(reduced_scaled.values()).item()

          if not math.isfinite(loss_value):
              print("Loss is {}, stopping training".format(loss_value))
              print(reduced)
              sys.exit(1)

          optimizer.zero_grad()
          total_loss.backward()
          if max_norm > 0:
              torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
          optimizer.step()

          logger.update(loss=loss_value, **reduced_scaled, **reduced_unscaled)
          logger.update(lr=optimizer.param_groups[0]["lr"])

      # gather the stats from all processes
      logger.synchronize_between_processes()
      print("Averaged stats:", logger)
      return {name: meter.global_avg for name, meter in logger.meters.items()}
  # the inference routine
  @torch.no_grad()
  def evaluate_crowd_no_overlap(model, data_loader, device, vis_dir=None, threshold=0.5):
      """Evaluate crowd counting on *data_loader*.

      An image's predicted count is the number of proposals whose class-1
      softmax probability is strictly greater than *threshold*. Only batch
      element 0 of every batch is scored, so the loader is presumably expected
      to use batch size 1 — confirm.

      Args:
          model: callable module returning a dict with ``'pred_logits'``
              ([B, N, C], C >= 2) and ``'pred_points'`` (aligned with the logits).
          data_loader: iterable of ``(samples, targets)``, where ``targets[0]['point']``
              holds the ground-truth points of the first image.
          device: device that *samples* are moved to.
          vis_dir: if given, ``vis`` writes annotated images there.
          threshold: score threshold; 0.5 by default.

      Returns:
          ``(mae, mse)``. ``mse`` is actually the RMSE: the square root of the
          mean squared count error. Both are NaN when *data_loader* is empty.
      """
      model.eval()
      maes = []
      mses = []
      for samples, targets in data_loader:
          samples = samples.to(device)
          outputs = model(samples)
          # Foreground (class-1) probability of every proposal of the first image.
          outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
          keep = outputs_scores > threshold
          gt_cnt = targets[0]['point'].shape[0]
          predict_cnt = int(keep.sum())
          # if specified, save the visualized images
          if vis_dir is not None:
              points = outputs['pred_points'][0][keep].detach().cpu().numpy().tolist()
              vis(samples, targets, [points], vis_dir)
          # accumulate absolute and squared count errors
          err = predict_cnt - gt_cnt
          maes.append(float(abs(err)))
          mses.append(float(err * err))
      if not maes:
          # np.mean([]) would warn "Mean of empty slice"; return NaN explicitly.
          return float('nan'), float('nan')
      mae = np.mean(maes)
      mse = np.sqrt(np.mean(mses))
      return mae, mse