落水人员检测
No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

29 líneas
985B

  1. import os
  2. from tqdm import tqdm
  3. import numpy as np
  4. from mypath import Path
  5. def calculate_weigths_labels(dataset, dataloader, num_classes):
  6. # Create an instance from the data loader
  7. z = np.zeros((num_classes,))
  8. # Initialize tqdm
  9. tqdm_batch = tqdm(dataloader)
  10. print('Calculating classes weights')
  11. for sample in tqdm_batch:
  12. y = sample['label']
  13. y = y.detach().cpu().numpy()
  14. mask = (y >= 0) & (y < num_classes)
  15. labels = y[mask].astype(np.uint8)
  16. count_l = np.bincount(labels, minlength=num_classes)
  17. z += count_l
  18. tqdm_batch.close()
  19. total_frequency = np.sum(z)
  20. class_weights = []
  21. for frequency in z:
  22. class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))
  23. class_weights.append(class_weight)
  24. ret = np.array(class_weights)
  25. classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy')
  26. np.save(classes_weights_path, ret)
  27. return ret