|
- import torch
- import numpy as np
- import torchvision.transforms as transforms
- import math
- from PIL import Image
def custom_mean(x):
    """Reduce ``x`` to a single value: ``prod(x) ** (2 / sqrt(len(x)))``.

    ``x`` must support ``len()`` and expose a ``.prod()`` method (for example
    a NumPy array).

    NOTE(review): an empty ``x`` makes the exponent ``2.0 / 0.0``.
    """
    product = x.prod()
    exponent = 2.0 / np.sqrt(len(x))
    return product ** exponent
-
def contrast_grey(img):
    """Measure how spread out the values of ``img`` are.

    Returns:
        ``(contrast, high, low)`` where ``high`` / ``low`` are the 90th / 10th
        percentiles of ``img`` and ``contrast`` is
        ``(high - low) / max(10, high + low)``.  The floor of 10 on the
        denominator keeps the ratio bounded when the pixel values are small.
    """
    low = np.percentile(img, 10)
    high = np.percentile(img, 90)
    denominator = np.maximum(10, high + low)
    contrast = (high - low) / denominator
    return contrast, high, low
-
def adjust_contrast_grey(img, target=0.4):
    """Stretch the grey levels of ``img`` when its contrast is below ``target``.

    ``contrast``, ``high`` (90th percentile) and ``low`` (10th percentile)
    come from ``contrast_grey``.  When ``contrast < target`` the pixels are
    shifted by ``25 - low``, multiplied by ``200 / max(10, high - low)``,
    clipped to ``[0, 255]`` and returned as a new ``uint8`` array.  Otherwise
    ``img`` itself is returned unchanged (same object, same dtype).

    Args:
        img: NumPy array of grey-level pixels (presumably in 0-255 -- TODO
            confirm for callers passing non-``uint8`` data).
        target: contrast threshold below which the image is adjusted.

    Returns:
        The adjusted ``uint8`` array, or ``img`` if no adjustment was needed.
    """
    contrast, high, low = contrast_grey(img)
    if contrast < target:
        # Convert to a signed integer dtype first so none of the arithmetic
        # runs in the input's (possibly unsigned) dtype.
        pixels = img.astype(int)
        ratio = 200. / np.maximum(10, high - low)
        pixels = (pixels - low + 25) * ratio
        # np.clip replaces a maximum/minimum pair taken against two freshly
        # allocated full-size constant arrays: same values, no extra arrays.
        img = np.clip(pixels, 0, 255).astype(np.uint8)
    return img
-
class NormalizePAD:
    """Normalise an image tensor and right-pad it to a fixed size.

    Args:
        max_size: ``(channels, height, width)`` of the tensor returned by
            ``__call__``.
        PAD_type: stored as ``self.PAD_type`` but not read by ``__call__``;
            padding is always added on the right.
    """

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        # Not read anywhere in this class; kept as a public attribute.
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        """Convert ``img`` and return a ``float32`` tensor of shape ``max_size``.

        ``img`` is converted with ``self.toTensor``, its values are mapped by
        ``(x - 0.5) / 0.5`` (i.e. ``[0, 1] -> [-1, 1]``) and copied into the
        leftmost columns of the output.  Any remaining columns are filled by
        repeating the image's last column rather than left at zero.
        """
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)  # in place: [0, 1] -> [-1, 1]
        c, h, w = img.size()
        # torch.zeros with an explicit dtype replaces the legacy
        # torch.FloatTensor(*size).fill_(0) constructor.
        padded = torch.zeros(*self.max_size, dtype=torch.float32, device=img.device)
        padded[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border pad
            padded[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
        return padded
-
class AlignCollate:
    """Collate a batch of images into one normalised, padded tensor.

    Args:
        imgH: height every image is resized to.
        imgW: maximum width after resizing; narrower images are padded to it.
        keep_ratio_with_pad: stored but not read by ``__call__``; the aspect
            ratio is always preserved (up to the ``imgW`` cap) and padded.
        adjust_contrast: when > 0, each image is converted to greyscale and
            passed to ``adjust_contrast_grey`` with this value as ``target``.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast=0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.adjust_contrast = adjust_contrast

    def __call__(self, batch):
        """Resize, normalise and pad every non-``None`` item of ``batch``.

        Items are used through a PIL-style interface (``.size``,
        ``.convert``, ``.resize``).

        Returns:
            A tensor of shape ``(N, 1, imgH, imgW)``, where ``N`` is the number
            of non-``None`` items; ``N`` may be 0.
        """
        images = [image for image in batch if image is not None]
        input_channel = 1
        if not images:
            # torch.stack/torch.cat reject an empty list; return an empty batch
            # with the usual trailing shape instead of raising.
            return torch.zeros((0, input_channel, self.imgH, self.imgW), dtype=torch.float32)

        transform = NormalizePAD((input_channel, self.imgH, self.imgW))

        resized_images = []
        for image in images:
            w, h = image.size
            # augmentation: optional contrast stretch in greyscale
            if self.adjust_contrast > 0:
                image = np.array(image.convert("L"))
                image = adjust_contrast_grey(image, target=self.adjust_contrast)
                image = Image.fromarray(image, 'L')

            # Width that keeps the aspect ratio at height imgH, capped at imgW.
            ratio = w / float(h)
            resized_w = min(math.ceil(self.imgH * ratio), self.imgW)

            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
            resized_images.append(transform(resized_image))

        return torch.stack(resized_images, 0)
-
-
class CTCLabelConverter:
    """Convert between text labels and CTC label indices.

    Label index 0 is the dummy ``'[blank]'`` token; the entries of
    ``character`` occupy indices ``1 .. len(character)`` in order.
    """

    def __init__(self, character, separator_list=None, dict_pathlist=None):
        """Build the index tables and load the optional word lists.

        Args:
            character: iterable of the possible characters (e.g. a ``str``).
            separator_list: mapping ``lang -> iterable of separator characters``.
                With ``k`` separators in total, label indices ``1 .. k`` are
                added to ``self.ignore_idx`` (this presumes the separators are
                the first ``k`` entries of ``character`` -- TODO confirm against
                callers).  ``None`` means no separators.
            dict_pathlist: mapping ``lang -> path`` of UTF-8 word-list files,
                one word per line.  ``None`` means no word lists.

        Raises:
            OSError, ValueError: only when ``separator_list`` is non-empty and
                a word-list file cannot be opened or decoded.  Without
                separators such files are skipped.
        """
        # None sentinels instead of ``{}`` defaults: ``separator_list`` is
        # stored on the instance, so a shared default dict mutated through one
        # converter would silently affect every later one.
        if separator_list is None:
            separator_list = {}
        if dict_pathlist is None:
            dict_pathlist = {}

        dict_character = list(character)
        # Later duplicates overwrite earlier ones, as with per-key assignment.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # dummy '[blank]' token for CTCLoss (index 0)
        self.character = ['[blank]'] + dict_character

        self.separator_list = separator_list
        separator_char = [char for seps in separator_list.values() for char in seps]
        self.ignore_idx = [0] + [i + 1 for i in range(len(separator_char))]

        if not separator_list:
            # No separators: merge every word list into one flat list; a file
            # that is missing or cannot be decoded is skipped (best effort).
            dict_list = []
            for dict_path in dict_pathlist.values():
                try:
                    dict_list += self._read_word_list(dict_path)
                except (OSError, ValueError):
                    # Formerly a bare ``except:``, which also swallowed
                    # KeyboardInterrupt/SystemExit and unrelated bugs.
                    continue
        else:
            # With separators: one word list per language; read errors
            # propagate to the caller.
            dict_list = {
                lang: self._read_word_list(dict_path)
                for lang, dict_path in dict_pathlist.items()
            }

        self.dict_list = dict_list

    @staticmethod
    def _read_word_list(path):
        """Return the lines of the UTF-8 text file ``path`` (a BOM is stripped)."""
        with open(path, "r", encoding="utf-8-sig") as input_file:
            return input_file.read().splitlines()

    def encode(self, text, batch_max_length=25):
        """Convert text labels into concatenated label indices for CTCLoss.

        Args:
            text: sequence of label strings, one per image.
            batch_max_length: accepted for interface compatibility; unused.

        Returns:
            ``(indices, lengths)`` as ``torch.IntTensor``: the indices of all
            labels concatenated (``[text_index_0 + ... + text_index_(n-1)]``)
            and the length of each label.

        Raises:
            KeyError: if a label contains a character not in the character set.
        """
        length = [len(s) for s in text]
        indices = [self.dict[char] for char in ''.join(text)]
        return (torch.IntTensor(indices), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """Greedily decode concatenated index sequences into strings.

        Args:
            text_index: all sequences back to back (assumed to be a 1-D NumPy
                integer array -- TODO confirm for other array types).
            length: length of each sequence in ``text_index``.

        Returns:
            One string per sequence.  Consecutive repeated indices are
            collapsed and indices in ``self.ignore_idx`` are dropped.
        """
        # Loop-invariant lookup tables, built once instead of per sequence.
        characters = np.array(self.character)
        ignore = np.array(self.ignore_idx)

        texts = []
        index = 0
        for seq_len in length:
            t = text_index[index:index + seq_len]
            # True where the value differs from its predecessor (first always kept).
            not_repeated = np.insert(~(t[1:] == t[:-1]), 0, True)
            # True where the value is not in the ignore list.
            not_ignored = ~np.isin(t, ignore)
            keep = not_repeated & not_ignored
            texts.append(''.join(characters[t[keep.nonzero()]]))
            index += seq_len
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        """Decode each ``mat[i]`` with ``ctcBeamSearch`` without passing a word list."""
        return [
            ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            for i in range(mat.shape[0])
        ]

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        """Decode each item of ``mat`` word by word with ``ctcBeamSearch``.

        ``mat`` is indexed as ``mat[item, step, class]`` (argmax is taken over
        axis 2; presumably (batch, time, num_classes) -- TODO confirm).

        Without separators, a word is a run of consecutive steps whose argmax
        is not the space character; each run is decoded against the merged
        ``self.dict_list`` and the results are joined with single spaces.
        With separators, ``word_segmentation`` supplies items whose ``[0]`` is
        a language key ('' for none) and whose ``[1]`` holds inclusive
        start/end steps; each span is decoded against that language's list.

        Raises:
            KeyError: without separators, if ' ' is not in the character set.
        """
        texts = []
        argmax = np.argmax(mat, axis=2)

        for i in range(mat.shape[0]):
            if not self.separator_list:
                # Without separators - use space as separator.
                space_idx = self.dict[' ']
                data = np.argwhere(argmax[i] != space_idx).flatten()
                groups = np.split(data, np.where(np.diff(data) != 1)[0] + 1)
                words = [
                    ctcBeamSearch(mat[i, list(group), :], self.character, self.ignore_idx, None,
                                  beamWidth=beamWidth, dict_list=self.dict_list)
                    for group in groups if len(group) > 0
                ]
                string = ' '.join(words)
            else:
                # With separators.
                string = ''
                for word in word_segmentation(argmax[i]):
                    matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                    dict_list = [] if word[0] == '' else self.dict_list[word[0]]
                    string += ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                            beamWidth=beamWidth, dict_list=dict_list)
            texts.append(string)
        return texts
|