192 lines
7.1 KiB
Python
192 lines
7.1 KiB
Python
import torch
|
|
import numpy as np
|
|
import torchvision.transforms as transforms
|
|
import math
|
|
from PIL import Image
|
|
def custom_mean(x):
    """Return the product of ``x`` raised to the power ``2 / sqrt(len(x))``.

    Args:
        x: array-like object exposing ``.prod()`` and ``len()`` (for example
            a 1-D ``numpy.ndarray``).

    Returns:
        ``x.prod() ** (2.0 / sqrt(len(x)))``.
    """
    exponent = 2.0 / np.sqrt(len(x))
    return x.prod() ** exponent
|
|
|
|
def contrast_grey(img):
    """Measure the contrast of an image from its 10th and 90th percentiles.

    Args:
        img: array of pixel intensities accepted by ``numpy.percentile``.

    Returns:
        A ``(contrast, high, low)`` tuple where ``high`` and ``low`` are the
        90th and 10th percentile values and
        ``contrast = (high - low) / max(10, high + low)``.
    """
    low, high = np.percentile(img, [10, 90])
    # The denominator is floored at 10 so very dark images do not divide by ~0.
    denominator = np.maximum(10, high + low)
    return (high - low) / denominator, high, low
|
|
|
|
def adjust_contrast_grey(img, target = 0.4):
    """Stretch the intensities of a low-contrast image.

    Args:
        img: ``numpy.ndarray`` of pixel intensities.
        target: contrast threshold; images whose ``contrast_grey`` value is
            already at or above it are returned untouched.

    Returns:
        ``img`` itself when its contrast is not below ``target``; otherwise a
        new ``uint8`` array with the 10th-90th percentile range rescaled and
        clamped to ``[0, 255]``.
    """
    contrast, high, low = contrast_grey(img)
    if not contrast < target:
        return img

    # Scale so the 10th-90th percentile span covers about 200 levels; the
    # span is floored at 10 to avoid an extreme gain on nearly flat images.
    gain = 200. / np.maximum(10, high - low)
    # Work in a wide integer type so the subtraction cannot wrap around.
    stretched = (img.astype(int) - low + 25) * gain
    return np.clip(stretched, 0, 255).astype(np.uint8)
|
|
|
|
class NormalizePAD(object):
    """Convert an image to a normalized tensor and pad it on the right.

    The image is converted with ``transforms.ToTensor()``, shifted and scaled
    with ``(x - 0.5) / 0.5``, then written into the left part of a tensor of
    shape ``max_size``. Any remaining columns on the right are filled by
    repeating the image's last column.
    """

    def __init__(self, max_size, PAD_type='right'):
        """Store the target size.

        Args:
            max_size: ``(channels, height, width)`` of the padded output.
            PAD_type: stored on the instance; not read by ``__call__``.
        """
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        # Stored for reference only; not used by __call__.
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        """Return ``img`` as a normalized float tensor of shape ``max_size``."""
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        channels, height, width = img.size()

        target_width = self.max_size[2]
        padded = torch.zeros(self.max_size, dtype=torch.float32)
        padded[:, :, :width] = img  # image occupies the left part
        if target_width == width:
            return padded

        # Fill the remaining columns by replicating the last image column.
        last_column = img[:, :, width - 1].unsqueeze(2)
        padded[:, :, width:] = last_column.expand(channels, height, target_width - width)
        return padded
|
|
|
|
class AlignCollate(object):
    """Collate a batch of images into one tensor of fixed height and width.

    Each image is optionally contrast-adjusted, resized to height ``imgH``
    with its aspect ratio preserved (width capped at ``imgW``), then
    normalized and right-padded to ``imgW`` by ``NormalizePAD``.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast = 0.):
        """Store the collate settings.

        Args:
            imgH: output height in pixels.
            imgW: output (maximum) width in pixels.
            keep_ratio_with_pad: stored on the instance; not read by
                ``__call__``.
            adjust_contrast: when greater than 0, images are converted to
                greyscale and passed to ``adjust_contrast_grey`` with this
                value as the target.
        """
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.adjust_contrast = adjust_contrast

    def __call__(self, batch):
        """Return the stacked tensor for every non-``None`` image in ``batch``."""
        single_channel = 1
        pad_to_canvas = NormalizePAD((single_channel, self.imgH, self.imgW))

        tensors = []
        for image in (item for item in batch if item is not None):
            # Size is read before any conversion below.
            w, h = image.size

            # Augmentation: stretch the contrast of low-contrast images.
            if self.adjust_contrast > 0:
                grey = np.array(image.convert("L"))
                grey = adjust_contrast_grey(grey, target = self.adjust_contrast)
                image = Image.fromarray(grey, 'L')

            # Keep the aspect ratio, but never exceed the canvas width.
            aspect = w / float(h)
            resized_w = min(math.ceil(self.imgH * aspect), self.imgW)

            resized = image.resize((resized_w, self.imgH), Image.BICUBIC)
            tensors.append(pad_to_canvas(resized))

        return torch.stack(tensors, 0)
|
|
|
|
|
|
class CTCLabelConverter(object):
    """Convert between text labels and CTC text indices.

    Index 0 is reserved for the CTC blank token; the characters passed to the
    constructor are numbered from 1 in the order given.
    """

    def __init__(self, character, separator_list=None, dict_pathlist=None):
        """Build the character tables and load the word dictionaries.

        Args:
            character (str): set of the possible characters.
            separator_list (dict | None): maps a language key to its separator
                characters. ``None`` (the default) means no separators.
            dict_pathlist (dict | None): maps a language key to the path of a
                word-list file (one word per line, UTF-8 with optional BOM).
                ``None`` (the default) means no dictionaries.

        Raises:
            OSError: when separators are configured and a dictionary file
                cannot be opened. Without separators, unreadable files are
                skipped instead.
        """
        # Fresh containers per instance: a shared mutable default would be
        # stored on ``self`` and leak between instances.
        separator_list = {} if separator_list is None else separator_list
        dict_pathlist = {} if dict_pathlist is None else dict_pathlist

        dict_character = list(character)
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}

        # dummy '[blank]' token for CTCLoss (index 0)
        self.character = ['[blank]'] + dict_character

        self.separator_list = separator_list
        separator_char = []
        for sep in separator_list.values():
            separator_char += sep
        # Blank (0) plus indices 1..len(separator_char).
        # NOTE(review): this presumes the separator characters are the first
        # entries of ``character`` -- confirm against the caller.
        self.ignore_idx = list(range(len(separator_char) + 1))

        if len(separator_list) == 0:
            # No separators: merge every readable dictionary into one list.
            dict_list = []
            for dict_path in dict_pathlist.values():
                try:
                    with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                        dict_list += input_file.read().splitlines()
                except (OSError, ValueError):
                    # Best effort: a missing, unreadable or undecodable
                    # dictionary is skipped rather than aborting construction.
                    continue
        else:
            # With separators: keep one word list per language key.
            dict_list = {}
            for lang, dict_path in dict_pathlist.items():
                with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                    dict_list[lang] = input_file.read().splitlines()

        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """Convert text labels into text indices.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: accepted for interface compatibility; not used.

        Returns:
            A ``(text, length)`` tuple of ``torch.IntTensor``:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]

        Raises:
            KeyError: if a label contains a character outside the character set.
        """
        length = [len(label) for label in text]
        indices = [self.dict[char] for label in text for char in label]
        return (torch.IntTensor(indices), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """Convert text indices into text labels.

        Args:
            text_index: flat sequence of predicted indices for the whole batch.
            length: number of indices belonging to each sample, in order.

        Returns:
            One decoded string per entry of ``length``. Consecutive repeats
            are collapsed and indices in ``self.ignore_idx`` are dropped.
        """
        text_index = np.asarray(text_index)
        charset = np.array(self.character)
        ignored = np.array(self.ignore_idx)

        texts = []
        start = 0
        for seq_len in length:
            seq_len = int(seq_len)  # also accepts 0-d tensor / numpy lengths
            t = text_index[start:start + seq_len]
            # True where the value differs from its predecessor (first is kept).
            keep = np.ones(len(t), dtype=bool)
            keep[1:] = t[1:] != t[:-1]
            # ...and where the value is not blank or a separator.
            keep &= ~np.isin(t, ignored)
            texts.append(''.join(charset[t[keep]]))
            start += seq_len
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        """Decode each sample of ``mat`` with ``ctcBeamSearch``.

        Args:
            mat: array indexed as ``mat[sample]``; each slice is handed to
                ``ctcBeamSearch`` unchanged.
            beamWidth: beam width forwarded to ``ctcBeamSearch``.

        Returns:
            List with one ``ctcBeamSearch`` result per sample.
        """
        # ``ctcBeamSearch`` must be provided by the enclosing module.
        return [
            ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            for i in range(mat.shape[0])
        ]

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        """Decode each sample word by word using the loaded dictionaries.

        Args:
            mat: 3-D array; ``argmax`` over axis 2 gives the greedy index at
                every step of every sample.
            beamWidth: beam width forwarded to ``ctcBeamSearch``.

        Returns:
            List with one decoded string per sample.

        Raises:
            KeyError: when no separators are configured and ``' '`` is not in
                the character set.
        """
        # ``ctcBeamSearch`` and ``word_segmentation`` must be provided by the
        # enclosing module.
        texts = []
        argmax = np.argmax(mat, axis=2)

        for i in range(mat.shape[0]):
            if len(self.separator_list) == 0:
                # without separators - use space as separator
                space_idx = self.dict[' ']

                # Split the non-space steps into runs of consecutive positions.
                data = np.argwhere(argmax[i] != space_idx).flatten()
                group = np.split(data, np.where(np.diff(data) != 1)[0] + 1)
                group = [list(item) for item in group if len(item) > 0]

                words = [
                    ctcBeamSearch(mat[i, list_idx, :], self.character, self.ignore_idx, None,
                                  beamWidth=beamWidth, dict_list=self.dict_list)
                    for list_idx in group
                ]
                string = ' '.join(words)
            else:
                # with separators
                pieces = []
                for word in word_segmentation(argmax[i]):
                    # word[1] holds the inclusive (start, end) step range.
                    matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                    # word[0] is the language key; '' means no dictionary.
                    dict_list = [] if word[0] == '' else self.dict_list[word[0]]
                    pieces.append(
                        ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                      beamWidth=beamWidth, dict_list=dict_list)
                    )
                string = ''.join(pieces)
            texts.append(string)
        return texts