AIlib2/ocrUtils2/ocrUtils.py

315 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import numpy as np
import torchvision.transforms as transforms
import math, yaml
from easydict import EasyDict as edict
from PIL import Image
import cv2
from torch.autograd import Variable
import time
import tensorrt as trt
def trt_version():
    """Return the installed TensorRT version string (``tensorrt.__version__``)."""
    version_string = trt.__version__
    return version_string
def torch_device_from_trt(device):
    """Map a TensorRT tensor location to the matching torch device.

    Args:
        device: a ``trt.TensorLocation`` value.

    Returns:
        ``torch.device("cuda")`` for ``trt.TensorLocation.DEVICE`` and
        ``torch.device("cpu")`` for ``trt.TensorLocation.HOST``.

    Raises:
        TypeError: if ``device`` is any other location.
    """
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    # Raise (not return) so an unknown location fails here instead of an
    # exception object being passed on as if it were a torch device.
    raise TypeError("%s is not supported by torch" % device)
def torch_dtype_from_trt(dtype):
    """Map a TensorRT data type to the matching torch dtype.

    Args:
        dtype: a ``trt.DataType`` value.

    Returns:
        The torch dtype for int8, bool (when TensorRT provides it), int32,
        float16 or float32.

    Raises:
        TypeError: if ``dtype`` has no mapping here.
    """
    # Feature-detect trt.bool rather than comparing version strings:
    # lexicographically '10.0' < '7.0', so a string check wrongly rejects bool
    # on TensorRT 10+.
    trt_bool = getattr(trt, "bool", None)
    if dtype == trt.int8:
        return torch.int8
    if trt_bool is not None and dtype == trt_bool:
        return torch.bool
    if dtype == trt.int32:
        return torch.int32
    if dtype == trt.float16:
        return torch.float16
    if dtype == trt.float32:
        return torch.float32
    raise TypeError("%s is not supported by torch" % dtype)
def OcrTrtForward(engine, inputs, contextFlag=False):
    """Run one synchronous TensorRT inference (``execute_v2``) on ``engine``.

    The engine's first binding is treated as the only input; every remaining
    binding is treated as an output.

    Args:
        engine: a deserialized TensorRT engine.
        inputs: list of input tensors. Only ``inputs[0]`` is used; its raw data
            pointer is handed to TensorRT, so it presumably must already be on
            the GPU with the dtype/shape the engine expects -- TODO confirm.
        contextFlag: ``False`` to create a new execution context for this call,
            otherwise an existing execution context to reuse.

    Returns:
        ``(preds, outstr)``. With a single output binding ``preds`` is
        ``output[0]`` (that output tensor indexed at its leading axis); with
        several output bindings it is the first output tensor. ``outstr`` is a
        human-readable per-stage timing string in milliseconds.
    """
    t0 = time.time()
    context = engine.create_execution_context() if not contextFlag else contextFlag
    names = [engine.get_tensor_name(index) for index in range(engine.num_bindings)]
    input_names = [names[0]]
    output_names = names[1:]
    batch_size = inputs[0].shape[0]
    bindings = [None] * (len(input_names) + len(output_names))
    t1 = time.time()

    # Create the output tensors and bind their memory.
    outputs = [None] * len(output_names)
    for i, output_name in enumerate(output_names):
        idx = engine.get_binding_index(output_name)  # binding name -> binding index
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
        shape = (batch_size,) + tuple(engine.get_binding_shape(idx))
        device = torch_device_from_trt(engine.get_location(idx))
        output = torch.empty(size=shape, dtype=dtype, device=device)
        outputs[i] = output
        bindings[idx] = output.data_ptr()  # bind the output data pointer
    t2 = time.time()

    # Hold a reference to every bound input until execute_v2 returns.
    # .contiguous() may return a new copy; if that copy were dropped right
    # after .data_ptr(), its memory could be released and reused before
    # TensorRT reads it.
    bound_inputs = []
    for input_name in input_names:
        idx = engine.get_binding_index(input_name)
        # Ideally inputs[i] per input binding; a single image is processed, so
        # every input binding is fed inputs[0].
        input_tensor = inputs[0].contiguous()
        bound_inputs.append(input_tensor)
        bindings[idx] = input_tensor.data_ptr()
    t3 = time.time()

    context.execute_v2(bindings)  # run inference
    t4 = time.time()

    if len(outputs) == 1:
        outputs = outputs[0]
    outstr = 'create Context:%.2f alloc memory:%.2f prepare input:%.2f conext infer:%.2f, total:%.2f' % (
        (t1 - t0) * 1000, (t2 - t1) * 1000, (t3 - t2) * 1000, (t4 - t3) * 1000, (t4 - t0) * 1000)
    return outputs[0], outstr
def np_resize_keepRation(img, inp_h, inp_w):
    """Scale ``img`` to height ``inp_h`` keeping aspect ratio, then fit width ``inp_w``.

    If the aspect-preserving width is narrower than ``inp_w``, the right side
    is padded by repeating the image's last column. Otherwise the image is
    resized (squeezed) to exactly ``(inp_h, inp_w)``.

    Args:
        img: an HxW or HxWxC numpy image.
        inp_h: target height in pixels.
        inp_w: target width in pixels.

    Returns:
        An array of shape ``(inp_h, inp_w)`` or ``(inp_h, inp_w, C)``.
    """
    img_h, img_w = img.shape[0:2]
    fy = inp_h / img_h
    # At least one column, so the edge padding always has a column to repeat
    # (and cv2.resize never gets a zero width).
    keep_w = max(1, int(img_w * fy))
    if keep_w >= inp_w:
        return cv2.resize(img, (inp_w, inp_h))
    # Resize to the *target* height so the result matches (inp_h, ...).
    img = cv2.resize(img, (keep_w, inp_h))
    # Pad only the width axis on the right; 'edge' repeats the last column.
    # cv2.resize may drop a singleton channel axis, so use the resized ndim.
    pad_width = ((0, 0), (0, inp_w - keep_w)) + ((0, 0),) * (img.ndim - 2)
    return np.pad(img, pad_width, mode='edge')
def recognition_ocr(config, img, model, converter, device, par=None):
    """Recognize the text in one image crop with a CRNN-style recognizer.

    Args:
        config: config object (as built by ``get_cfg``). Reads
            ``MODEL.IMAGE_SIZE.H``, ``DATASET.MEAN`` and ``DATASET.STD``.
        img: an HxW grayscale or HxWx3 BGR numpy image.
        model: a torch module, or a TensorRT engine when
            ``par['model_mode'] == 'trt'``.
        converter: object with ``decode(preds, lengths, raw=...)``
            (e.g. ``strLabelConverter``).
        device: torch device the preprocessed input is moved to.
        par: options dict (``None`` means empty). Keys read:
            ``'model_mode'`` ('trt' selects TensorRT, anything else runs
            ``model`` as a torch module); ``'contextFlag'`` (reusable TensorRT
            execution context, or False); and in TRT mode ``'imgH'`` / ``'imgW'``
            (fixed engine input size).

    Returns:
        The decoded text string.
    """
    par = {} if par is None else par
    model_mode = par.get('model_mode')
    contextFlag = par.get('contextFlag', False)

    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
    h, _ = img.shape
    # First step: scale so the height equals config.MODEL.IMAGE_SIZE.H (aspect kept).
    scale = config.MODEL.IMAGE_SIZE.H / h
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    if model_mode == 'trt':
        # The TensorRT engine takes a fixed input size: pad or squeeze to it.
        img = np_resize_keepRation(img, par['imgH'], par['imgW'])
    img = np.expand_dims(img, axis=2)  # HxW -> HxWx1

    # Normalize.
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    img = img.transpose([2, 0, 1])  # HWC -> CHW
    img = torch.from_numpy(img).to(device)
    img = img.view(1, *img.size())  # add a batch axis -> 1xCxHxW

    if model_mode == 'trt':
        # TensorRT reads raw device pointers, so feed it the GPU copy.
        img_input = img.to('cuda:0')
        preds, _ = OcrTrtForward(model, [img_input], contextFlag)
    else:
        model.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            preds = model(img)

    # Argmax over axis 2 (presumably the class axis of a (T, B, C) output --
    # TODO confirm), then flatten to a 1-D label sequence.
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])
    return converter.decode(preds, preds_size, raw=False)
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC. Label 0 is the blank; the
        characters of ``alphabet`` get labels 1..len(alphabet).

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        # '-' sits at index -1, so raw decoding renders the blank label 0
        # (alphabet[0 - 1]) as '-'.
        self.alphabet = alphabet + '-'
        # NOTE: 0 is reserved for 'blank' required by wrap_ctc.
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str, bytes, or list of str/bytes): texts to convert. Bytes are
                decoded as UTF-8. When ``ignore_case`` is set, the text is
                lower-cased to match the lower-cased alphabet.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        if isinstance(text, (str, bytes)):
            text = [text]  # a single text is a batch of one
        lengths = []
        labels = []
        for item in text:
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                item = item.lower()
            lengths.append(len(item))
            labels.extend(self.dict[char] for char in item)
        return (torch.IntTensor(labels), torch.IntTensor(lengths))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t: torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            length: torch.IntTensor [n]: length of each text.
            raw (bool): if True, keep blanks (as '-') and repeated labels;
                otherwise apply CTC collapsing (drop blanks, merge repeats).

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): a str when ``length`` has one element,
            otherwise a list with one str per text.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            # Convert once to Python ints instead of indexing the tensor per element.
            labels = t.tolist()
            if raw:
                return ''.join([self.alphabet[label - 1] for label in labels])
            char_list = []
            for i, label in enumerate(labels):
                if label != 0 and not (i > 0 and labels[i - 1] == label):
                    char_list.append(self.alphabet[label - 1])
            return ''.join(char_list)
        # batch mode
        assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
        texts = []
        index = 0
        for text_len in length.tolist():
            texts.append(
                self.decode(t[index:index + text_len], torch.IntTensor([text_len]), raw=raw))
            index += text_len
        return texts
def get_alphabets(txtfile):
    """Read a character set from ``txtfile``.

    Each line is stripped of surrounding whitespace and the lines are
    concatenated in file order. Because of the strip, a line holding only
    whitespace (e.g. a space character) contributes nothing.

    Args:
        txtfile: path to the alphabet text file (read as UTF-8).

    Returns:
        The concatenated alphabet string.
    """
    print(txtfile)
    # Decode explicitly as UTF-8: the alphabet presumably contains non-ASCII
    # (e.g. CJK) characters, and the platform default encoding is not
    # guaranteed to be UTF-8.
    with open(txtfile, 'r', encoding='utf-8') as fp:
        return "".join(line.strip() for line in fp)
def get_cfg(cfg, char_file):
    """Load a YAML config and attach the recognizer's alphabet to it.

    Args:
        cfg: path to the YAML config file (read as UTF-8).
        char_file: path to the alphabet file, passed (stripped) to ``get_alphabets``.

    Returns:
        An ``EasyDict`` of the config with ``DATASET.ALPHABETS`` set to the
        alphabet string and ``MODEL.NUM_CLASSES`` set to its length.
    """
    # NOTE(review): yaml.FullLoader resolves more tags than yaml.safe_load;
    # config files are assumed to be trusted.
    with open(cfg, 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)
    config.DATASET.ALPHABETS = get_alphabets(char_file.strip())
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    return config
def custom_mean(x):
    """Return ``x.prod() ** (2 / sqrt(len(x)))``, a length-damped product of ``x``."""
    exponent = 2.0 / np.sqrt(len(x))
    product = x.prod()
    return product ** exponent
def contrast_grey(img):
    """Estimate the contrast of ``img`` from its 10th and 90th percentiles.

    Returns:
        ``(contrast, high, low)`` where ``high``/``low`` are the 90th/10th
        percentiles of ``img`` and ``contrast = (high - low) / max(10, high + low)``.
    """
    low = np.percentile(img, 10)
    high = np.percentile(img, 90)
    denominator = np.maximum(10, high + low)
    contrast = (high - low) / denominator
    return contrast, high, low
def adjust_contrast_grey(img, target=0.4):
    """Stretch the intensities of a low-contrast grey image.

    When ``contrast_grey(img)`` reports a contrast below ``target``, pixel values
    are shifted by ``-low + 25``, scaled by ``200 / max(10, high - low)``,
    clipped to [0, 255] and returned as ``uint8``. Otherwise ``img`` is
    returned unchanged.
    """
    contrast, high, low = contrast_grey(img)
    if contrast < target:
        scale = 200. / np.maximum(10, high - low)
        stretched = (img.astype(int) - low + 25) * scale
        img = np.clip(stretched, 0, 255).astype(np.uint8)
    return img
class NormalizePAD(object):
    """Normalize an image tensor to [-1, 1] and right-pad it to a fixed size.

    Args:
        max_size: ``(C, H, W)`` of the returned tensor.
        PAD_type: stored on the instance; not read by this class.
    """

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)  # not read within this class
        self.PAD_type = PAD_type

    def __call__(self, img):
        tensor = self.toTensor(img)
        # ToTensor scales 8-bit images to [0, 1]; map that range to [-1, 1].
        tensor.sub_(0.5).div_(0.5)
        channels, height, width = tensor.size()
        target_width = self.max_size[2]
        padded = torch.FloatTensor(*self.max_size).zero_()
        padded[:, :, :width] = tensor  # image on the left, padding on the right
        if width != target_width:
            # Fill the padding by repeating the image's last column.
            last_column = tensor[:, :, width - 1].unsqueeze(2)
            padded[:, :, width:] = last_column.expand(channels, height, target_width - width)
        return padded
class AlignCollate(object):
    """Collate PIL images into one padded, normalized batch tensor.

    Args:
        imgH: output height of every image.
        imgW: maximum (and padded) output width.
        keep_ratio_with_pad: stored on the instance; not read by this class.
        adjust_contrast: if > 0, each image is converted to 'L' and passed
            through ``adjust_contrast_grey`` with this value as the target.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast=0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.adjust_contrast = adjust_contrast

    def __call__(self, batch):
        """Return a (N, 1, imgH, imgW) tensor built from the non-None images in ``batch``."""
        pad_to_width = NormalizePAD((1, self.imgH, self.imgW))
        batch_tensors = []
        for image in (item for item in batch if item is not None):
            width, height = image.size
            # Augmentation: optional contrast adjustment.
            if self.adjust_contrast > 0:
                grey = np.array(image.convert("L"))
                grey = adjust_contrast_grey(grey, target=self.adjust_contrast)
                image = Image.fromarray(grey, 'L')
            # Keep the aspect ratio, but never exceed the fixed batch width.
            target_width = min(self.imgW, math.ceil(self.imgH * (width / float(height))))
            resized = image.resize((target_width, self.imgH), Image.BICUBIC)
            batch_tensors.append(pad_to_width(resized).unsqueeze(0))
        return torch.cat(batch_tensors, 0)