交通事故检测代码
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

96 lines
3.2KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from loss.util import enet_weighing
  7. import numpy as np
  8. class OhemCELoss(nn.Module):
  9. def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
  10. super(OhemCELoss, self).__init__()
  11. self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
  12. self.n_min = n_min
  13. self.ignore_lb = ignore_lb
  14. self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
  15. def forward(self, logits, labels):
  16. N, C, H, W = logits.size()
  17. loss = self.criteria(logits, labels).view(-1)
  18. loss, _ = torch.sort(loss, descending=True)
  19. if loss[self.n_min] > self.thresh:
  20. loss = loss[loss>self.thresh]
  21. else:
  22. loss = loss[:self.n_min]
  23. return torch.mean(loss)
  24. class WeightedOhemCELoss(nn.Module):
  25. def __init__(self, thresh, n_min, num_classes, ignore_lb=255, *args, **kwargs):
  26. super(WeightedOhemCELoss, self).__init__()
  27. self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
  28. self.n_min = n_min
  29. self.ignore_lb = ignore_lb
  30. self.num_classes = num_classes
  31. # self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
  32. def forward(self, logits, labels):
  33. N, C, H, W = logits.size()
  34. criteria = nn.CrossEntropyLoss(weight=enet_weighing(labels, self.num_classes).cuda(), ignore_index=self.ignore_lb, reduction='none')
  35. loss = criteria(logits, labels).view(-1)
  36. loss, _ = torch.sort(loss, descending=True)
  37. if loss[self.n_min] > self.thresh:
  38. loss = loss[loss>self.thresh]
  39. else:
  40. loss = loss[:self.n_min]
  41. return torch.mean(loss)
  42. class SoftmaxFocalLoss(nn.Module):
  43. def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
  44. super(FocalLoss, self).__init__()
  45. self.gamma = gamma
  46. self.nll = nn.NLLLoss(ignore_index=ignore_lb)
  47. def forward(self, logits, labels):
  48. scores = F.softmax(logits, dim=1)
  49. factor = torch.pow(1.-scores, self.gamma)
  50. log_score = F.log_softmax(logits, dim=1)
  51. log_score = factor * log_score
  52. loss = self.nll(log_score, labels)
  53. return loss
  54. if __name__ == '__main__':
  55. torch.manual_seed(15)
  56. criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
  57. criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
  58. net1 = nn.Sequential(
  59. nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
  60. )
  61. net1.cuda()
  62. net1.train()
  63. net2 = nn.Sequential(
  64. nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
  65. )
  66. net2.cuda()
  67. net2.train()
  68. with torch.no_grad():
  69. inten = torch.randn(16, 3, 20, 20).cuda()
  70. lbs = torch.randint(0, 19, [16, 20, 20]).cuda()
  71. lbs[1, :, :] = 255
  72. logits1 = net1(inten)
  73. logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
  74. logits2 = net2(inten)
  75. logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
  76. loss1 = criteria1(logits1, lbs)
  77. loss2 = criteria2(logits2, lbs)
  78. loss = loss1 + loss2
  79. print(loss.detach().cpu())
  80. loss.backward()