高速公路违停检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

43 lines
1.4KB

  1. import numpy as np
  2. import torch
  3. def enet_weighing(label, num_classes, c=1.02):
  4. """Computes class weights as described in the ENet paper:
  5. w_class = 1 / (ln(c + p_class)),
  6. where c is usually 1.02 and p_class is the propensity score of that
  7. class:
  8. propensity_score = freq_class / total_pixels.
  9. References: https://arxiv.org/abs/1606.02147
  10. Keyword arguments:
  11. - dataloader (``data.Dataloader``): A data loader to iterate over the
  12. dataset.
  13. - num_classes (``int``): The number of classes.
  14. - c (``int``, optional): AN additional hyper-parameter which restricts
  15. the interval of values for the weights. Default: 1.02.
  16. """
  17. class_count = 0
  18. total = 0
  19. label = label.cpu().numpy()
  20. # Flatten label
  21. flat_label = label.flatten()
  22. # Sum up the number of pixels of each class and the total pixel
  23. # counts for each label
  24. class_count += np.bincount(flat_label, minlength=num_classes)
  25. total += flat_label.size
  26. # Compute propensity score and then the weights for each class
  27. propensity_score = class_count / total
  28. class_weights = 1 / (np.log(c + propensity_score))
  29. class_weights = torch.from_numpy(class_weights).float()
  30. # print(class_weights)
  31. return class_weights
  32. def minmax_scale(input_arr):
  33. min_val = np.min(input_arr)
  34. max_val = np.max(input_arr)
  35. output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)
  36. return output_arr