# STDC-th/loss/util.py
#
# Loss helpers: ENet-style class weighting and min-max scaling.

import numpy as np
import torch
def enet_weighing(label, num_classes, c=1.02):
    """Compute per-class weights as described in the ENet paper.

    Each class weight is::

        w_class = 1 / ln(c + p_class)

    where ``p_class`` is the class's propensity score, i.e. the fraction of
    valid pixels in ``label`` that belong to that class::

        p_class = freq_class / total_valid_pixels

    Pixels whose value lies outside ``[0, num_classes)`` (for example an
    ignore label such as 255, or -1) are excluded from both the class counts
    and the pixel total, so the result always has exactly ``num_classes``
    entries.

    Reference: https://arxiv.org/abs/1606.02147

    Args:
        label: Integer class-index labels, either a ``torch.Tensor`` (on any
            device) or anything ``np.asarray`` accepts. Any shape; it is
            flattened before counting.
        num_classes (int): Number of classes; length of the returned tensor.
        c (float, optional): Additional hyper-parameter that restricts the
            range of the weights. It should be > 1 so that ``ln(c + p)`` stays
            positive for every class. Default: 1.02.

    Returns:
        torch.Tensor: 1-D ``float32`` CPU tensor of shape ``(num_classes,)``.
        If ``label`` contains no valid pixels, every weight is ``1 / ln(c)``.
    """
    if isinstance(label, torch.Tensor):
        label = label.detach().cpu().numpy()
    flat_label = np.asarray(label).ravel()

    # Drop ignore/out-of-range labels: np.bincount rejects negative values and
    # would otherwise grow the count array past num_classes (e.g. to 256 when
    # the ignore label 255 is present).
    valid = flat_label[(flat_label >= 0) & (flat_label < num_classes)]
    class_count = np.bincount(valid.astype(np.int64), minlength=num_classes)
    total = valid.size

    if total > 0:
        propensity_score = class_count / total
    else:
        # No valid pixels: avoid 0/0 (NaN weights) and treat all classes as absent.
        propensity_score = np.zeros(num_classes, dtype=np.float64)

    class_weights = 1 / np.log(c + propensity_score)
    return torch.from_numpy(class_weights).float()
def minmax_scale(input_arr, max_out=255.0):
    """Linearly rescale ``input_arr`` so that its values span ``[0, max_out]``.

    The minimum element maps to 0 and the maximum element maps to ``max_out``
    (255 by default). No clipping or rounding is applied.

    Args:
        input_arr: Array-like of numeric values (converted with ``np.asarray``).
        max_out (float, optional): Value that the maximum element maps to.
            Default: 255.0.

    Returns:
        np.ndarray: Rescaled values with the same shape as ``input_arr``.
        A constant input (max == min) maps to all zeros instead of NaN, and
        an empty input returns an empty array.
    """
    arr = np.asarray(input_arr)
    if arr.size == 0:
        # min/max of an empty array raise; there is nothing to scale.
        return arr * max_out
    min_val = arr.min()
    span = arr.max() - min_val
    scaled = (arr - min_val) * max_out
    if span == 0:
        # Constant input: every shifted value is already 0, so skip the 0/0.
        return scaled
    return scaled / span