import numpy as np
import torch


def enet_weighing(label, num_classes, c=1.02):
    """Computes class weights as described in the ENet paper:

        w_class = 1 / (ln(c + p_class)),

    where c is usually 1.02 and p_class is the propensity score of that
    class:

        propensity_score = freq_class / total_pixels.

    References: https://arxiv.org/abs/1606.02147

    Keyword arguments:
    - label (``torch.Tensor``): Tensor of ground-truth class labels from
    which the per-class pixel frequencies are computed.
    - num_classes (``int``): The number of classes.
    - c (``float``, optional): An additional hyper-parameter which restricts
    the interval of values for the weights. Default: 1.02.
    """
    # Move the labels to host memory and flatten them into a 1D array
    label = label.cpu().numpy()
    flat_label = label.flatten()

    # Count the number of pixels of each class and the total pixel count
    class_count = np.bincount(flat_label, minlength=num_classes)
    total = flat_label.size

    # Compute the propensity score and then the weights for each class
    propensity_score = class_count / total
    class_weights = 1 / np.log(c + propensity_score)

    return torch.from_numpy(class_weights).float()

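# Illustrative usage sketch (an assumption, not part of the original
# utilities): the weights returned by ``enet_weighing`` can be fed straight
# into a weighted PyTorch loss. The label shape, the 12-class setup and the
# helper name ``_example_enet_weighing`` are made up for this example.
def _example_enet_weighing():
    import torch.nn as nn

    # Fake ground-truth labels: a batch of 4 segmentation maps, 12 classes
    labels = torch.randint(0, 12, (4, 360, 480))
    weights = enet_weighing(labels, num_classes=12)

    # Weighted cross-entropy using the ENet class weights
    return nn.CrossEntropyLoss(weight=weights)
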
def minmax_scale(input_arr):
    """Linearly rescales an array so that its values span [0, 255]."""
    min_val = np.min(input_arr)
    max_val = np.max(input_arr)

    # Shift to zero and scale by the value range (assumes max_val != min_val)
    output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)
    return output_arr
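
# Quick illustrative self-check (an assumption, not part of the original
# file): scale a random array to [0, 255] and print the new extremes.
if __name__ == "__main__":
    arr = np.random.rand(4, 4) * 10.0
    scaled = minmax_scale(arr)
    print(scaled.min(), scaled.max())  # expected: 0.0 and 255.0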