AIlib2/ocrUtils/ocrUtils.py

192 lines
7.1 KiB
Python

import torch
import numpy as np
import torchvision.transforms as transforms
import math
from PIL import Image
def custom_mean(x):
    """Combine the values of ``x`` into one score: ``prod(x) ** (2 / sqrt(len(x)))``.

    ``x`` must provide ``.prod()`` and ``len()`` (e.g. a NumPy array).
    """
    product = x.prod()
    exponent = 2.0 / np.sqrt(len(x))
    return product ** exponent
def contrast_grey(img):
    """Estimate the contrast of an image from its 10th and 90th percentiles.

    Returns:
        tuple ``(contrast, high, low)`` where ``high``/``low`` are the 90th/10th
        percentiles of ``img`` and ``contrast = (high - low) / max(10, high + low)``.
    """
    high = np.percentile(img, 90)
    low = np.percentile(img, 10)
    spread = high - low
    # Floor the denominator at 10 so a tiny high + low cannot blow up the ratio.
    denominator = np.maximum(10, high + low)
    return spread / denominator, high, low
def adjust_contrast_grey(img, target=0.4):
    """Stretch the contrast of a greyscale image whose contrast is below ``target``.

    Args:
        img: image array (anything ``np.percentile`` accepts).
        target: minimum acceptable contrast, as measured by ``contrast_grey``.

    Returns:
        ``img`` itself if its contrast is already >= ``target``; otherwise a new
        ``uint8`` array computed as ``(img - low + 25) * 200 / max(10, high - low)``
        and clipped to [0, 255], where ``high``/``low`` come from ``contrast_grey``.
    """
    contrast, high, low = contrast_grey(img)
    if contrast < target:
        img = img.astype(int)
        ratio = 200. / np.maximum(10, high - low)
        img = (img - low + 25) * ratio
        # Clip into the uint8 range before casting so values cannot wrap around.
        img = np.clip(img, 0, 255).astype(np.uint8)
    return img
class NormalizePAD(object):
    """Normalise an image tensor and right-pad it to a fixed size.

    Args:
        max_size: target tensor size ``(channels, height, max_width)``.
        PAD_type: stored on the instance but not read by ``__call__``; padding
            is always applied on the right.
    """

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)  # not read within this class
        self.PAD_type = PAD_type

    def __call__(self, img):
        """Convert ``img`` to a tensor, normalise it and pad it to ``max_size``.

        The image occupies the left ``w`` columns; every column to its right is
        filled with a copy of the image's last column (not with zeros).

        Returns:
            float32 tensor of shape ``max_size``.
        """
        img = self.toTensor(img)
        # (x - 0.5) / 0.5 in place: values in [0, 1] map to [-1, 1].
        img.sub_(0.5).div_(0.5)
        c, h, w = img.size()
        pad_img = torch.zeros(self.max_size, dtype=torch.float32)
        pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:
            # Replicate the last image column across the padded region.
            pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
        return pad_img
class AlignCollate(object):
    """Collate a batch of PIL images into one padded tensor.

    Each image is optionally contrast-adjusted, resized to height ``imgH``
    keeping its aspect ratio (width capped at ``imgW``), then normalised and
    right-padded to width ``imgW`` by ``NormalizePAD``.

    Args:
        imgH: output height in pixels.
        imgW: output (maximum) width in pixels.
        keep_ratio_with_pad: stored but not read by ``__call__``; the
            keep-ratio-and-pad path is always used.
        adjust_contrast: if > 0, each image is converted to 'L' and passed to
            ``adjust_contrast_grey`` with this value as ``target``.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast=0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.adjust_contrast = adjust_contrast

    def __call__(self, batch):
        """Resize, normalise and pad every non-None image in ``batch``.

        Returns:
            tensor of shape ``(N, 1, imgH, imgW)``, where ``N`` is the number of
            non-None items in ``batch``. ``torch.stack`` raises if that number is 0.

        NOTE(review): the padded tensor has one channel, so images are
        presumably expected to be greyscale even when adjust_contrast is 0.
        TODO: confirm with callers.
        """
        images = [image for image in batch if image is not None]
        input_channel = 1
        transform = NormalizePAD((input_channel, self.imgH, self.imgW))
        resized_images = []
        for image in images:
            w, h = image.size
            if self.adjust_contrast > 0:
                # Augmentation: change contrast on a greyscale copy.
                image = np.array(image.convert("L"))
                image = adjust_contrast_grey(image, target=self.adjust_contrast)
                image = Image.fromarray(image, 'L')
            ratio = w / float(h)
            # Keep the aspect ratio at height imgH, but never exceed imgW.
            resized_w = min(self.imgW, math.ceil(self.imgH * ratio))
            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
            resized_images.append(transform(resized_image))
        return torch.stack(resized_images, 0)
class CTCLabelConverter(object):
    """Convert between text labels and CTC index sequences.

    Index 0 is reserved for the CTC blank; character ``k`` of ``character``
    is mapped to index ``k + 1``.
    """

    def __init__(self, character, separator_list=None, dict_pathlist=None):
        """Build the character/index tables and load optional word lists.

        Args:
            character (str): set of the possible characters.
            separator_list: mapping of language -> separator characters
                (defaults to an empty dict). The values are concatenated and
                their total count decides how many indices after the blank go
                into ``ignore_idx``.
            dict_pathlist: mapping of language -> path of a word-list file, read
                as UTF-8 (BOM tolerated), one entry per line. Defaults to an
                empty dict.
        """
        # None sentinels: the mapping is stored on the instance, so a shared
        # mutable default would leak mutations between instances.
        if separator_list is None:
            separator_list = {}
        if dict_pathlist is None:
            dict_pathlist = {}

        dict_character = list(character)
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # Dummy '[blank]' token for CTCLoss (index 0).
        self.character = ['[blank]'] + dict_character

        self.separator_list = separator_list
        separator_char = []
        for sep in separator_list.values():
            separator_char += sep
        # Ignore the blank plus indices 1..len(separator_char). This marks the
        # FIRST len(separator_char) entries of `character` as ignored, whatever
        # they are. Presumably callers place separators first; TODO: confirm.
        self.ignore_idx = [0] + [i + 1 for i in range(len(separator_char))]

        if len(separator_list) == 0:
            # No separators: merge all word lists into one flat list,
            # skipping any file that cannot be read (best effort).
            dict_list = []
            for dict_path in dict_pathlist.values():
                try:
                    with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                        dict_list += input_file.read().splitlines()
                except (OSError, UnicodeDecodeError):
                    pass
        else:
            # With separators: keep one word list per language; read errors propagate.
            dict_list = {}
            for lang, dict_path in dict_pathlist.items():
                with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                    dict_list[lang] = input_file.read().splitlines()
        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """Convert text labels into concatenated index sequences.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: accepted for interface compatibility; not used.

        Returns:
            ``(indices, lengths)`` as ``torch.IntTensor``s. ``indices`` is the
            concatenation of all labels' indices [sum(text_lengths)], and
            ``lengths`` holds each label's length [batch_size].

        Raises:
            KeyError: if a label contains a character missing from ``character``.
        """
        length = [len(s) for s in text]
        indices = [self.dict[char] for char in ''.join(text)]
        return (torch.IntTensor(indices), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """Greedily decode concatenated CTC index sequences into strings.

        Consecutive repeated indices are collapsed, and indices in
        ``ignore_idx`` are dropped.

        Args:
            text_index: flat integer array of all samples' indices, concatenated.
                Presumably a NumPy array, since NumPy boolean/nonzero indexing is
                applied to slices of it. TODO: confirm for torch input.
            length: number of entries of ``text_index`` belonging to each sample.

        Returns:
            list of decoded strings, one per entry of ``length``.
        """
        # Loop-invariant lookup arrays, built once.
        characters = np.array(self.character)
        ignored = np.array(self.ignore_idx)
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]
            # True where the value differs from its predecessor (first always kept).
            not_repeated = np.insert(~(t[1:] == t[:-1]), 0, True)
            # True where the value is not an ignored index.
            not_ignored = ~np.isin(t, ignored)
            keep = not_repeated & not_ignored
            texts.append(''.join(characters[t[keep.nonzero()]]))
            index += l
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        """Decode each ``mat[i]`` with ``ctcBeamSearch`` and return the strings.

        NOTE(review): ctcBeamSearch is not imported at the top of this module;
        confirm it is defined in this module, otherwise this raises NameError.
        """
        return [
            ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            for i in range(mat.shape[0])
        ]

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        """Decode each sample word by word with ``ctcBeamSearch`` and the word lists.

        ``mat`` is indexed as a 3-D array and argmax'd over axis 2. It is
        presumably (batch, time, classes) scores; TODO: confirm.

        Without separators, frames whose best class is ' ' split the sequence
        into words, which are beam-searched against ``self.dict_list`` and
        joined with single spaces. With separators, ``word_segmentation`` yields
        ``(lang, (start, end))`` spans, each searched against that language's
        word list (or none when ``lang == ''``) and concatenated.

        NOTE(review): ctcBeamSearch and word_segmentation are not imported at
        the top of this module; confirm they are defined in this module.
        """
        texts = []
        argmax = np.argmax(mat, axis=2)
        for i in range(mat.shape[0]):
            if len(self.separator_list) == 0:
                # Without separators: use space as separator.
                space_idx = self.dict[' ']
                data = np.argwhere(argmax[i] != space_idx).flatten()
                # Split into runs of consecutive non-space frame indices.
                group = np.split(data, np.where(np.diff(data) != 1)[0] + 1)
                group = [list(item) for item in group if len(item) > 0]
                words = [
                    ctcBeamSearch(mat[i, list_idx, :], self.character, self.ignore_idx, None,
                                  beamWidth=beamWidth, dict_list=self.dict_list)
                    for list_idx in group
                ]
                string = ' '.join(words)
            else:
                pieces = []
                for word in word_segmentation(argmax[i]):
                    matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                    dict_list = [] if word[0] == '' else self.dict_list[word[0]]
                    pieces.append(ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                                beamWidth=beamWidth, dict_list=dict_list))
                string = ''.join(pieces)
            texts.append(string)
        return texts