43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
def enet_weighing(label, num_classes, c=1.02):
    """Computes class weights as described in the ENet paper:

        w_class = 1 / (ln(c + p_class)),

    where c is usually 1.02 and p_class is the propensity score of that
    class, i.e. its pixel frequency within ``label``:

        propensity_score = freq_class / total_pixels.

    References: https://arxiv.org/abs/1606.02147

    Keyword arguments:
    - label (``torch.Tensor`` or array-like): Integer class labels. Any
    shape is accepted; it is flattened before counting. Values must be
    non-negative (``np.bincount`` rejects negative values).
    - num_classes (``int``): The number of classes. The result has at least
    this many entries; a label value ``>= num_classes`` lengthens it.
    - c (``float``, optional): An additional hyper-parameter which restricts
    the interval of values for the weights. Default: 1.02.

    Returns:
    A 1-D float32 ``torch.Tensor`` of per-class weights. If ``label`` is
    empty, every class gets the weight ``1 / ln(c)`` (all frequencies are
    treated as zero) instead of NaN.
    """
    # Accept a torch tensor (possibly on GPU) as well as anything numpy can wrap.
    if isinstance(label, torch.Tensor):
        label = label.detach().cpu().numpy()
    flat_label = np.asarray(label).ravel()

    # Pixel count per class; minlength guarantees an entry for every class,
    # including classes absent from this label.
    class_count = np.bincount(flat_label, minlength=num_classes)
    total = flat_label.size

    # Guard the empty case: class_count / 0 would produce NaN weights.
    if total == 0:
        propensity_score = np.zeros(class_count.shape, dtype=np.float64)
    else:
        propensity_score = class_count / total

    # c > 1 keeps ln(c + p) strictly positive, bounding the weights.
    class_weights = 1.0 / np.log(c + propensity_score)

    return torch.from_numpy(class_weights).float()
|
|
|
|
def minmax_scale(input_arr, out_max=255.0):
    """Linearly rescales an array so its values span ``[0, out_max]``.

    The minimum element maps to 0 and the maximum maps to ``out_max``.

    Keyword arguments:
    - input_arr (array-like): Values to rescale; converted with
    ``np.asarray``. Must be non-empty (``np.min`` raises on empty input).
    - out_max (``float``, optional): Value the maximum element maps to.
    Default: 255.0.

    Returns:
    An array of the same shape as ``input_arr``. If every element is equal,
    an all-zero array is returned instead of dividing by zero.
    """
    arr = np.asarray(input_arr)
    min_val = np.min(arr)
    max_val = np.max(arr)

    shifted = arr - min_val
    span = max_val - min_val
    if span == 0:
        # Constant input: every element sits at the minimum, so map it to 0
        # rather than dividing by zero (which would yield NaN or raise).
        return np.zeros_like(shifted, dtype=np.result_type(shifted, out_max))

    # Same operation order as before: scale first, then divide by the span.
    return shifted * out_max / span