29 lines
985 B
Python
29 lines
985 B
Python
import os
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
from mypath import Path
|
|
|
|
def calculate_weigths_labels(dataset, dataloader, num_classes):
    """Compute per-class weights from label frequencies and save them to disk.

    Every batch yielded by ``dataloader`` is scanned, the occurrences of each
    label id in ``[0, num_classes)`` are counted, and each class ``c`` receives
    the weight ``1 / log(1.02 + count_c / total_count)``, so rarer classes get
    larger weights. The weights are written to
    ``<Path.db_root_dir(dataset)>/<dataset>_classes_weights.npy`` and returned.

    Args:
        dataset: Dataset name; passed to ``Path.db_root_dir`` and used as the
            prefix of the saved file name.
        dataloader: Iterable of samples. Each sample must support
            ``sample['label']`` and that value must provide
            ``.detach().cpu().numpy()`` (presumably a torch tensor - confirm
            against the caller).
        num_classes: Number of classes; label values outside
            ``[0, num_classes)`` are ignored.

    Returns:
        A float NumPy array of shape ``(num_classes,)`` holding the weights.
        If no in-range label is seen at all, the total count is zero and every
        weight evaluates to NaN.
    """
    class_counts = np.zeros((num_classes,))

    # The progress bar is created before the message is printed, as before.
    tqdm_batch = tqdm(dataloader)
    print('Calculating classes weights')
    try:
        for sample in tqdm_batch:
            y = sample['label'].detach().cpu().numpy()
            # Keep only valid class ids (drops ignore/void labels such as 255
            # or negative values when they fall outside the class range).
            mask = (y >= 0) & (y < num_classes)
            # Cast to the platform index type rather than uint8: uint8 would
            # wrap ids >= 256 and corrupt the counts when num_classes > 256.
            labels = y[mask].astype(np.intp)
            class_counts += np.bincount(labels, minlength=num_classes)
    finally:
        # Always release the progress bar, even if a batch raises.
        tqdm_batch.close()

    total_frequency = np.sum(class_counts)
    # Vectorized form of the per-class formula 1 / log(1.02 + frequency).
    ret = 1 / np.log(1.02 + class_counts / total_frequency)

    classes_weights_path = os.path.join(
        Path.db_root_dir(dataset), dataset + '_classes_weights.npy'
    )
    np.save(classes_weights_path, ret)

    return ret