from torch import nn

from .. import functional as F


class PSAMask(nn.Module):
    """Module wrapper that forwards its input to ``F.psa_mask``.

    The constructor validates and stores the configuration; ``forward``
    passes the input together with the stored configuration to
    ``F.psa_mask`` and returns whatever that function returns.

    Args:
        psa_type: Mode selector, must be ``0`` or ``1`` ("0-col, 1-dis";
            presumably "collect" / "distribute" -- confirm against
            ``F.psa_mask``).
        mask_H_: Optional mask height handed to ``F.psa_mask``. Must be
            given together with ``mask_W_`` or not at all.
        mask_W_: Optional mask width handed to ``F.psa_mask``. Must be
            given together with ``mask_H_`` or not at all.

    Raises:
        AssertionError: If ``psa_type`` is not 0 or 1, or if exactly one of
            ``mask_H_`` / ``mask_W_`` is ``None``.
    """

    def __init__(self, psa_type=0, mask_H_=None, mask_W_=None):
        super(PSAMask, self).__init__()
        assert psa_type in [0, 1], "psa_type must be 0 or 1"  # 0-col, 1-dis
        # Identity checks are required here: a membership test against None
        # (``x in None``) raises TypeError instead of evaluating the condition.
        # The mask sizes must be either both omitted or both provided.
        assert (mask_H_ is None) == (mask_W_ is None), (
            "mask_H_ and mask_W_ must both be None or both be set"
        )
        self.psa_type = psa_type
        self.mask_H_ = mask_H_
        self.mask_W_ = mask_W_

    def forward(self, input):
        """Return ``F.psa_mask`` applied to ``input`` with the stored settings."""
        return F.psa_mask(input, self.psa_type, self.mask_H_, self.mask_W_)