import torch
import numpy as np
import torchvision.transforms as transforms
import math, yaml
from easydict import EasyDict as edict
from PIL import Image
import cv2
from torch.autograd import Variable
import time
import tensorrt as trt


def trt_version():
    return trt.__version__

def torch_device_from_trt(device):
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    elif device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    else:
        raise TypeError("%s is not supported by torch" % device)


def torch_dtype_from_trt(dtype):
    if dtype == trt.int8:
        return torch.int8
    # Compare the major version numerically: a plain string compare would sort '10' before '7'.
    elif int(trt_version().split('.')[0]) >= 7 and dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)

def OcrTrtForward(engine, inputs, contextFlag=False):
    t0 = time.time()
    # Reuse the caller's execution context when one is passed in; otherwise
    # create a fresh one for this call.
    if not contextFlag:
        context = engine.create_execution_context()
    else:
        context = contextFlag

    names = [engine.get_tensor_name(index) for index in range(engine.num_bindings)]
    input_names = [names[0]]
    output_names = names[1:]

    batch_size = inputs[0].shape[0]
    bindings = [None] * (len(input_names) + len(output_names))
    t1 = time.time()

    # Create the output tensors and allocate their memory.
    outputs = [None] * len(output_names)
    for i, output_name in enumerate(output_names):
        idx = engine.get_binding_index(output_name)  # binding index for this name
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))  # matching torch dtype
        shape = (batch_size,) + tuple(engine.get_binding_shape(idx))  # matching shape
        device = torch_device_from_trt(engine.get_location(idx))
        output = torch.empty(size=shape, dtype=dtype, device=device)
        outputs[i] = output
        bindings[idx] = output.data_ptr()  # bind the output data pointer
    t2 = time.time()

    for i, input_name in enumerate(input_names):
        idx = engine.get_binding_index(input_name)
        # This should be inputs[i] when the engine takes several inputs; since
        # a single image is used here, every input is bound to the same tensor.
        bindings[idx] = inputs[0].contiguous().data_ptr()
    t3 = time.time()

    context.execute_v2(bindings)  # run inference
    t4 = time.time()

    if len(outputs) == 1:
        outputs = outputs[0]
    outstr = 'create context:%.2f alloc memory:%.2f prepare input:%.2f context infer:%.2f, total:%.2f' % (
        (t1 - t0) * 1000, (t2 - t1) * 1000, (t3 - t2) * 1000, (t4 - t3) * 1000, (t4 - t0) * 1000)
    return outputs[0], outstr

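# A minimal usage sketch for OcrTrtForward (not from the original source): the
# engine path and input shape below are hypothetical placeholders; substitute
# whatever your exported OCR engine actually expects.
#
#     logger = trt.Logger(trt.Logger.WARNING)
#     with open('crnn_ocr.trt', 'rb') as f, trt.Runtime(logger) as runtime:
#         engine = runtime.deserialize_cuda_engine(f.read())
#     context = engine.create_execution_context()
#     dummy = torch.randn(1, 1, 32, 160, device='cuda')  # NCHW grayscale input
#     preds, timing = OcrTrtForward(engine, [dummy], context)
#     print(timing)
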
def np_resize_keepRation(img, inp_h, inp_w):
    img_h, img_w = img.shape[0:2]

    # Scale so the height becomes inp_h while the width keeps its aspect ratio.
    fy = inp_h / img_h
    keep_w = int(img_w * fy)
    img = cv2.resize(img, (keep_w, inp_h))
    # E.g. the resized width is 120 while the target is 160: columns 120-160
    # are filled with the border value instead of stretching the image.
    if keep_w < inp_w:
        if len(img.shape) == 3:
            img_out = np.zeros((inp_h, inp_w, 3), dtype=np.uint8)
            img_out[:, :keep_w] = img[:, :]
            for j in range(3):
                img_out[:, keep_w:, j] = np.tile(img[:, keep_w - 1:, j], inp_w - keep_w)
        else:
            img_out = np.zeros((inp_h, inp_w), dtype=np.uint8)
            img_out[:, :keep_w] = img[:, :]
            img_out[:, keep_w:] = np.tile(img[:, keep_w - 1:], inp_w - keep_w)
    else:
        img_out = cv2.resize(img, (inp_w, inp_h))
    return img_out

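# Example of what np_resize_keepRation does, on a synthetic image (not from
# the original source): a 32x120 crop resized into a 32x160 canvas keeps its
# aspect ratio and repeats the rightmost column into the remaining 40 pixels.
#
#     crop = np.random.randint(0, 255, (32, 120), dtype=np.uint8)
#     padded = np_resize_keepRation(crop, 32, 160)
#     assert padded.shape == (32, 160)
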
def recognition_ocr(config, img, model, converter, device, par={}):
    model_mode = par['model_mode']
    contextFlag = par['contextFlag']
    if len(img.shape) == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
    h, w = img.shape
    # first step: resize the height and width of the image to (32, x)
    img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)
    if model_mode == 'trt':
        img = np_resize_keepRation(img, par['imgH'], par['imgW'])
    img = np.expand_dims(img, axis=2)

    # normalize
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img)

    img = img.to(device)
    img = img.view(1, *img.size())

    if model_mode == 'trt':
        preds, trtstr = OcrTrtForward(model, [img], contextFlag)
    else:
        model.eval()
        preds = model(img)

    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    return sim_pred

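# A hedged end-to-end sketch of recognition_ocr in 'trt' mode (not from the
# original source). The config/char-table paths are placeholders, and `engine`
# is a deserialized TensorRT engine as in the sketch after OcrTrtForward.
#
#     config = get_cfg('crnn_config.yaml', 'chars.txt')
#     converter = strLabelConverter(config.DATASET.ALPHABETS)
#     par = {'model_mode': 'trt', 'contextFlag': False, 'imgH': 32, 'imgW': 160}
#     text = recognition_ocr(config, cv2.imread('plate.jpg'), engine, converter,
#                            torch.device('cuda:0'), par=par)
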
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore character case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        decode_flag = isinstance(text[0], bytes)

        for item in text:
            if decode_flag:
                item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the text and its declared length do not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

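# Round-trip example for strLabelConverter (not from the original source):
# index 0 is the CTC blank, and decode() with raw=False collapses repeated
# labels and drops blanks.
#
#     conv = strLabelConverter('abc')
#     codes, lengths = conv.encode(['ab'])      # tensors [1, 2] and [2]
#     conv.decode(codes, lengths)               # -> 'ab'
#     conv.decode(torch.IntTensor([1, 1, 0, 2]),
#                 torch.IntTensor([4]))         # -> 'ab' (CTC collapse)
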
def get_alphabets(txtfile):
    print(txtfile)
    with open(txtfile, 'r') as fp:
        lines = fp.readlines()
    alphas = [x.strip() for x in lines]
    return "".join(alphas)

def get_cfg(cfg, char_file):
    with open(cfg, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)

    config.DATASET.ALPHABETS = get_alphabets(char_file.strip())
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    return config

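# Hedged usage of get_cfg: both file names below are placeholders. The YAML is
# assumed to carry MODEL and DATASET sections (as in the
# CRNN_Chinese_Characters_Rec configs referenced above); the char file lists
# one alphabet character per line.
#
#     config = get_cfg('lib/config/OWN_config.yaml', 'lib/config/char.txt')
#     print(config.MODEL.NUM_CLASSES)
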
def custom_mean(x):
    return x.prod() ** (2.0 / np.sqrt(len(x)))

def contrast_grey(img):
    # Contrast measure based on the spread between the 90th and 10th
    # percentile grey values; also returns both percentiles.
    high = np.percentile(img, 90)
    low = np.percentile(img, 10)
    return (high - low) / np.maximum(10, high + low), high, low

def adjust_contrast_grey(img, target=0.4):
    # Linearly stretch the grey range of low-contrast images toward the target.
    contrast, high, low = contrast_grey(img)
    if contrast < target:
        img = img.astype(int)
        ratio = 200. / np.maximum(10, high - low)
        img = (img - low + 25) * ratio
        img = np.maximum(np.full(img.shape, 0), np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
    return img

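# Quick check of the contrast helpers on a deliberately flat grey image (not
# from the original source): contrast_grey reports a value well under the 0.4
# default target, so adjust_contrast_grey stretches the range.
#
#     flat = np.random.randint(100, 120, (32, 100), dtype=np.uint8)
#     c, high, low = contrast_grey(flat)        # c is roughly 0.07
#     stretched = adjust_contrast_grey(flat)    # much wider value range
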
class NormalizePAD(object):

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)  # scale from [0, 1] to [-1, 1]
        c, h, w = img.size()
        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border pad
            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return Pad_img

class AlignCollate(object):

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast=0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.adjust_contrast = adjust_contrast

    def __call__(self, batch):
        images = [x for x in batch if x is not None]

        resized_max_w = self.imgW
        input_channel = 1
        transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

        resized_images = []
        for image in images:
            w, h = image.size
            #### augmentation here - change contrast
            if self.adjust_contrast > 0:
                image = np.array(image.convert("L"))
                image = adjust_contrast_grey(image, target=self.adjust_contrast)
                image = Image.fromarray(image, 'L')

            ratio = w / float(h)
            if math.ceil(self.imgH * ratio) > self.imgW:
                resized_w = self.imgW
            else:
                resized_w = math.ceil(self.imgH * ratio)

            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
            resized_images.append(transform(resized_image))

        image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
        return image_tensors
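
# AlignCollate is meant to be used as a DataLoader collate_fn over PIL images;
# a minimal sketch with synthetic inputs (the data here is hypothetical, not
# from the original source):
#
#     collate = AlignCollate(imgH=32, imgW=100, adjust_contrast=0.4)
#     pils = [Image.fromarray(np.random.randint(0, 255, (32, w), dtype=np.uint8), 'L')
#             for w in (60, 80, 100)]
#     batch = collate(pils)        # torch.Size([3, 1, 32, 100])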