落水人员检测
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

29 行
985B

  1. import os
  2. from tqdm import tqdm
  3. import numpy as np
  4. from mypath import Path
  5. def calculate_weigths_labels(dataset, dataloader, num_classes):
  6. # Create an instance from the data loader
  7. z = np.zeros((num_classes,))
  8. # Initialize tqdm
  9. tqdm_batch = tqdm(dataloader)
  10. print('Calculating classes weights')
  11. for sample in tqdm_batch:
  12. y = sample['label']
  13. y = y.detach().cpu().numpy()
  14. mask = (y >= 0) & (y < num_classes)
  15. labels = y[mask].astype(np.uint8)
  16. count_l = np.bincount(labels, minlength=num_classes)
  17. z += count_l
  18. tqdm_batch.close()
  19. total_frequency = np.sum(z)
  20. class_weights = []
  21. for frequency in z:
  22. class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))
  23. class_weights.append(class_weight)
  24. ret = np.array(class_weights)
  25. classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy')
  26. np.save(classes_weights_path, ret)
  27. return ret