落水人员检测
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

29 lines
985B

  1. import os
  2. from tqdm import tqdm
  3. import numpy as np
  4. from mypath import Path
  5. def calculate_weigths_labels(dataset, dataloader, num_classes):
  6. # Create an instance from the data loader
  7. z = np.zeros((num_classes,))
  8. # Initialize tqdm
  9. tqdm_batch = tqdm(dataloader)
  10. print('Calculating classes weights')
  11. for sample in tqdm_batch:
  12. y = sample['label']
  13. y = y.detach().cpu().numpy()
  14. mask = (y >= 0) & (y < num_classes)
  15. labels = y[mask].astype(np.uint8)
  16. count_l = np.bincount(labels, minlength=num_classes)
  17. z += count_l
  18. tqdm_batch.close()
  19. total_frequency = np.sum(z)
  20. class_weights = []
  21. for frequency in z:
  22. class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))
  23. class_weights.append(class_weight)
  24. ret = np.array(class_weights)
  25. classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy')
  26. np.save(classes_weights_path, ret)
  27. return ret