You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

518 lines
16KB

  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Misc functions, including distributed helpers.
  4. Mostly copy-paste from torchvision references.
  5. """
  6. import os
  7. import subprocess
  8. import time
  9. from collections import defaultdict, deque
  10. import datetime
  11. import pickle
  12. from typing import Optional, List
  13. import torch
  14. import torch.distributed as dist
  15. from torch import Tensor
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from torch.autograd import Variable
  19. # needed due to empty tensor bug in pytorch and torchvision 0.5
  20. import torchvision
  21. # if float(torchvision.__version__[:3]) < 0.7:
  22. # from torchvision.ops import _new_empty_tensor
  23. # from torchvision.ops.misc import _output_size
  24. class SmoothedValue(object):
  25. """Track a series of values and provide access to smoothed values over a
  26. window or the global series average.
  27. """
  28. def __init__(self, window_size=20, fmt=None):
  29. if fmt is None:
  30. fmt = "{median:.4f} ({global_avg:.4f})"
  31. self.deque = deque(maxlen=window_size)
  32. self.total = 0.0
  33. self.count = 0
  34. self.fmt = fmt
  35. def update(self, value, n=1):
  36. self.deque.append(value)
  37. self.count += n
  38. self.total += value * n
  39. def synchronize_between_processes(self):
  40. """
  41. Warning: does not synchronize the deque!
  42. """
  43. if not is_dist_avail_and_initialized():
  44. return
  45. t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
  46. dist.barrier()
  47. dist.all_reduce(t)
  48. t = t.tolist()
  49. self.count = int(t[0])
  50. self.total = t[1]
  51. @property
  52. def median(self):
  53. d = torch.tensor(list(self.deque))
  54. return d.median().item()
  55. @property
  56. def avg(self):
  57. d = torch.tensor(list(self.deque), dtype=torch.float32)
  58. return d.mean().item()
  59. @property
  60. def global_avg(self):
  61. return self.total / self.count
  62. @property
  63. def max(self):
  64. return max(self.deque)
  65. @property
  66. def value(self):
  67. return self.deque[-1]
  68. def __str__(self):
  69. return self.fmt.format(
  70. median=self.median,
  71. avg=self.avg,
  72. global_avg=self.global_avg,
  73. max=self.max,
  74. value=self.value)
  75. def all_gather(data):
  76. """
  77. Run all_gather on arbitrary picklable data (not necessarily tensors)
  78. Args:
  79. data: any picklable object
  80. Returns:
  81. list[data]: list of data gathered from each rank
  82. """
  83. world_size = get_world_size()
  84. if world_size == 1:
  85. return [data]
  86. # serialized to a Tensor
  87. buffer = pickle.dumps(data)
  88. storage = torch.ByteStorage.from_buffer(buffer)
  89. tensor = torch.ByteTensor(storage).to("cuda")
  90. # obtain Tensor size of each rank
  91. local_size = torch.tensor([tensor.numel()], device="cuda")
  92. size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
  93. dist.all_gather(size_list, local_size)
  94. size_list = [int(size.item()) for size in size_list]
  95. max_size = max(size_list)
  96. # receiving Tensor from all ranks
  97. # we pad the tensor because torch all_gather does not support
  98. # gathering tensors of different shapes
  99. tensor_list = []
  100. for _ in size_list:
  101. tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
  102. if local_size != max_size:
  103. padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
  104. tensor = torch.cat((tensor, padding), dim=0)
  105. dist.all_gather(tensor_list, tensor)
  106. data_list = []
  107. for size, tensor in zip(size_list, tensor_list):
  108. buffer = tensor.cpu().numpy().tobytes()[:size]
  109. data_list.append(pickle.loads(buffer))
  110. return data_list
  111. def reduce_dict(input_dict, average=True):
  112. """
  113. Args:
  114. input_dict (dict): all the values will be reduced
  115. average (bool): whether to do average or sum
  116. Reduce the values in the dictionary from all processes so that all processes
  117. have the averaged results. Returns a dict with the same fields as
  118. input_dict, after reduction.
  119. """
  120. world_size = get_world_size()
  121. if world_size < 2:
  122. return input_dict
  123. with torch.no_grad():
  124. names = []
  125. values = []
  126. # sort the keys so that they are consistent across processes
  127. for k in sorted(input_dict.keys()):
  128. names.append(k)
  129. values.append(input_dict[k])
  130. values = torch.stack(values, dim=0)
  131. dist.all_reduce(values)
  132. if average:
  133. values /= world_size
  134. reduced_dict = {k: v for k, v in zip(names, values)}
  135. return reduced_dict
  136. class MetricLogger(object):
  137. def __init__(self, delimiter="\t"):
  138. self.meters = defaultdict(SmoothedValue)
  139. self.delimiter = delimiter
  140. def update(self, **kwargs):
  141. for k, v in kwargs.items():
  142. if isinstance(v, torch.Tensor):
  143. v = v.item()
  144. assert isinstance(v, (float, int))
  145. self.meters[k].update(v)
  146. def __getattr__(self, attr):
  147. if attr in self.meters:
  148. return self.meters[attr]
  149. if attr in self.__dict__:
  150. return self.__dict__[attr]
  151. raise AttributeError("'{}' object has no attribute '{}'".format(
  152. type(self).__name__, attr))
  153. def __str__(self):
  154. loss_str = []
  155. for name, meter in self.meters.items():
  156. loss_str.append(
  157. "{}: {}".format(name, str(meter))
  158. )
  159. return self.delimiter.join(loss_str)
  160. def synchronize_between_processes(self):
  161. for meter in self.meters.values():
  162. meter.synchronize_between_processes()
  163. def add_meter(self, name, meter):
  164. self.meters[name] = meter
  165. def log_every(self, iterable, print_freq, header=None):
  166. i = 0
  167. if not header:
  168. header = ''
  169. start_time = time.time()
  170. end = time.time()
  171. iter_time = SmoothedValue(fmt='{avg:.4f}')
  172. data_time = SmoothedValue(fmt='{avg:.4f}')
  173. space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
  174. if torch.cuda.is_available():
  175. log_msg = self.delimiter.join([
  176. header,
  177. '[{0' + space_fmt + '}/{1}]',
  178. 'eta: {eta}',
  179. '{meters}',
  180. 'time: {time}',
  181. 'data: {data}',
  182. 'max mem: {memory:.0f}'
  183. ])
  184. else:
  185. log_msg = self.delimiter.join([
  186. header,
  187. '[{0' + space_fmt + '}/{1}]',
  188. 'eta: {eta}',
  189. '{meters}',
  190. 'time: {time}',
  191. 'data: {data}'
  192. ])
  193. MB = 1024.0 * 1024.0
  194. for obj in iterable:
  195. data_time.update(time.time() - end)
  196. yield obj
  197. iter_time.update(time.time() - end)
  198. if i % print_freq == 0 or i == len(iterable) - 1:
  199. eta_seconds = iter_time.global_avg * (len(iterable) - i)
  200. eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
  201. if torch.cuda.is_available():
  202. print(log_msg.format(
  203. i, len(iterable), eta=eta_string,
  204. meters=str(self),
  205. time=str(iter_time), data=str(data_time),
  206. memory=torch.cuda.max_memory_allocated() / MB))
  207. else:
  208. print(log_msg.format(
  209. i, len(iterable), eta=eta_string,
  210. meters=str(self),
  211. time=str(iter_time), data=str(data_time)))
  212. i += 1
  213. end = time.time()
  214. total_time = time.time() - start_time
  215. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  216. print('{} Total time: {} ({:.4f} s / it)'.format(
  217. header, total_time_str, total_time / len(iterable)))
  218. def get_sha():
  219. cwd = os.path.dirname(os.path.abspath(__file__))
  220. def _run(command):
  221. return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
  222. sha = 'N/A'
  223. diff = "clean"
  224. branch = 'N/A'
  225. try:
  226. sha = _run(['git', 'rev-parse', 'HEAD'])
  227. subprocess.check_output(['git', 'diff'], cwd=cwd)
  228. diff = _run(['git', 'diff-index', 'HEAD'])
  229. diff = "has uncommited changes" if diff else "clean"
  230. branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
  231. except Exception:
  232. pass
  233. message = f"sha: {sha}, status: {diff}, branch: {branch}"
  234. return message
  235. def collate_fn(batch):
  236. batch = list(zip(*batch))
  237. batch[0] = nested_tensor_from_tensor_list(batch[0])
  238. return tuple(batch)
  239. def collate_fn_crowd(batch):
  240. # re-organize the batch
  241. batch_new = []
  242. for b in batch:
  243. imgs, points = b
  244. if imgs.ndim == 3:
  245. imgs = imgs.unsqueeze(0)
  246. for i in range(len(imgs)):
  247. batch_new.append((imgs[i, :, :, :], points[i]))
  248. batch = batch_new
  249. batch = list(zip(*batch))
  250. batch[0] = nested_tensor_from_tensor_list(batch[0])
  251. return tuple(batch)
  252. def _max_by_axis(the_list):
  253. # type: (List[List[int]]) -> List[int]
  254. maxes = the_list[0]
  255. for sublist in the_list[1:]:
  256. for index, item in enumerate(sublist):
  257. maxes[index] = max(maxes[index], item)
  258. return maxes
  259. def _max_by_axis_pad(the_list):
  260. # type: (List[List[int]]) -> List[int]
  261. maxes = the_list[0]
  262. for sublist in the_list[1:]:
  263. for index, item in enumerate(sublist):
  264. maxes[index] = max(maxes[index], item)
  265. block = 128
  266. for i in range(2):
  267. maxes[i+1] = ((maxes[i+1] - 1) // block + 1) * block
  268. return maxes
  269. def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
  270. # TODO make this more general
  271. if tensor_list[0].ndim == 3:
  272. # TODO make it support different-sized images
  273. max_size = _max_by_axis_pad([list(img.shape) for img in tensor_list])
  274. # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
  275. batch_shape = [len(tensor_list)] + max_size
  276. b, c, h, w = batch_shape
  277. dtype = tensor_list[0].dtype
  278. device = tensor_list[0].device
  279. tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
  280. for img, pad_img in zip(tensor_list, tensor):
  281. pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
  282. else:
  283. raise ValueError('not supported')
  284. return tensor
  285. class NestedTensor(object):
  286. def __init__(self, tensors, mask: Optional[Tensor]):
  287. self.tensors = tensors
  288. self.mask = mask
  289. def to(self, device):
  290. # type: (Device) -> NestedTensor # noqa
  291. cast_tensor = self.tensors.to(device)
  292. mask = self.mask
  293. if mask is not None:
  294. assert mask is not None
  295. cast_mask = mask.to(device)
  296. else:
  297. cast_mask = None
  298. return NestedTensor(cast_tensor, cast_mask)
  299. def decompose(self):
  300. return self.tensors, self.mask
  301. def __repr__(self):
  302. return str(self.tensors)
  303. def setup_for_distributed(is_master):
  304. """
  305. This function disables printing when not in master process
  306. """
  307. import builtins as __builtin__
  308. builtin_print = __builtin__.print
  309. def print(*args, **kwargs):
  310. force = kwargs.pop('force', False)
  311. if is_master or force:
  312. builtin_print(*args, **kwargs)
  313. __builtin__.print = print
  314. def is_dist_avail_and_initialized():
  315. if not dist.is_available():
  316. return False
  317. if not dist.is_initialized():
  318. return False
  319. return True
  320. def get_world_size():
  321. if not is_dist_avail_and_initialized():
  322. return 1
  323. return dist.get_world_size()
  324. def get_rank():
  325. if not is_dist_avail_and_initialized():
  326. return 0
  327. return dist.get_rank()
  328. def is_main_process():
  329. return get_rank() == 0
  330. def save_on_master(*args, **kwargs):
  331. if is_main_process():
  332. torch.save(*args, **kwargs)
  333. def init_distributed_mode(args):
  334. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  335. args.rank = int(os.environ["RANK"])
  336. args.world_size = int(os.environ['WORLD_SIZE'])
  337. args.gpu = int(os.environ['LOCAL_RANK'])
  338. elif 'SLURM_PROCID' in os.environ:
  339. args.rank = int(os.environ['SLURM_PROCID'])
  340. args.gpu = args.rank % torch.cuda.device_count()
  341. else:
  342. print('Not using distributed mode')
  343. args.distributed = False
  344. return
  345. args.distributed = True
  346. torch.cuda.set_device(args.gpu)
  347. args.dist_backend = 'nccl'
  348. print('| distributed init (rank {}): {}'.format(
  349. args.rank, args.dist_url), flush=True)
  350. torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
  351. world_size=args.world_size, rank=args.rank)
  352. torch.distributed.barrier()
  353. setup_for_distributed(args.rank == 0)
  354. @torch.no_grad()
  355. def accuracy(output, target, topk=(1,)):
  356. """Computes the precision@k for the specified values of k"""
  357. if target.numel() == 0:
  358. return [torch.zeros([], device=output.device)]
  359. maxk = max(topk)
  360. batch_size = target.size(0)
  361. _, pred = output.topk(maxk, 1, True, True)
  362. pred = pred.t()
  363. correct = pred.eq(target.view(1, -1).expand_as(pred))
  364. res = []
  365. for k in topk:
  366. correct_k = correct[:k].view(-1).float().sum(0)
  367. res.append(correct_k.mul_(100.0 / batch_size))
  368. return res
  369. def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
  370. # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
  371. """
  372. Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
  373. This will eventually be supported natively by PyTorch, and this
  374. class can go away.
  375. """
  376. if float(torchvision.__version__[:3]) < 0.7:
  377. if input.numel() > 0:
  378. return torch.nn.functional.interpolate(
  379. input, size, scale_factor, mode, align_corners
  380. )
  381. output_shape = _output_size(2, input, size, scale_factor)
  382. output_shape = list(input.shape[:-2]) + list(output_shape)
  383. return _new_empty_tensor(input, output_shape)
  384. else:
  385. return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
  386. class FocalLoss(nn.Module):
  387. r"""
  388. This criterion is a implemenation of Focal Loss, which is proposed in
  389. Focal Loss for Dense Object Detection.
  390. Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
  391. The losses are averaged across observations for each minibatch.
  392. Args:
  393. alpha(1D Tensor, Variable) : the scalar factor for this criterion
  394. gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
  395. putting more focus on hard, misclassified examples
  396. size_average(bool): By default, the losses are averaged over observations for each minibatch.
  397. However, if the field size_average is set to False, the losses are
  398. instead summed for each minibatch.
  399. """
  400. def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
  401. super(FocalLoss, self).__init__()
  402. if alpha is None:
  403. self.alpha = Variable(torch.ones(class_num, 1))
  404. else:
  405. if isinstance(alpha, Variable):
  406. self.alpha = alpha
  407. else:
  408. self.alpha = Variable(alpha)
  409. self.gamma = gamma
  410. self.class_num = class_num
  411. self.size_average = size_average
  412. def forward(self, inputs, targets):
  413. N = inputs.size(0)
  414. C = inputs.size(1)
  415. P = F.softmax(inputs)
  416. class_mask = inputs.data.new(N, C).fill_(0)
  417. class_mask = Variable(class_mask)
  418. ids = targets.view(-1, 1)
  419. class_mask.scatter_(1, ids.data, 1.)
  420. if inputs.is_cuda and not self.alpha.is_cuda:
  421. self.alpha = self.alpha.cuda()
  422. alpha = self.alpha[ids.data.view(-1)]
  423. probs = (P*class_mask).sum(1).view(-1,1)
  424. log_p = probs.log()
  425. batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
  426. if self.size_average:
  427. loss = batch_loss.mean()
  428. else:
  429. loss = batch_loss.sum()
  430. return loss