AIlib2/ocr.py

202 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import tensorrt as trt
import sys,os
import cv2,glob,time
import torch
import utils
import numpy as np
import torch.nn.functional as F
from ocrUtils2.ocrUtils import strLabelConverter , OcrTrtForward,np_resize_keepRation
class ocrModel(object):
    """CRNN text recogniser that runs either a TensorRT engine or a PyTorch checkpoint.

    The backend is chosen from the weight file's extension: ``.engine`` selects
    TensorRT (``infer_type == 'trt'``), ``.pth``/``.pt`` selects PyTorch
    (``infer_type == 'pth'``).
    """

    @staticmethod
    def _default_par():
        """Return a fresh copy of the default configuration.

        Built per call because ``__init__`` writes ``modelSize`` and
        ``modelType`` into the dict; a shared default would leak state between
        instances.
        """
        return {
            'char_file': '../AIlib2/weights/conf/OCR_Ch/Ch.txt',
            'mode': 'ch',
            'nc': 3,
            'imgH': 32,
            'imgW': 256,
            'hidden': 256,
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'dynamic': False,
        }

    @staticmethod
    def _resolve_infer_type(weights):
        """Map a weight path to ``'trt'``, ``'pth'`` or ``None`` (unsupported/missing)."""
        name = '' if weights is None else str(weights)
        if name.endswith('.engine'):
            return 'trt'
        if name.endswith(('.pth', '.pt')):
            return 'pth'
        return None

    def __init__(self, weights=None, par=None):
        """Load the alphabet and the model.

        Args:
            weights: Path to a ``.engine`` (TensorRT) or ``.pth``/``.pt``
                (PyTorch) file.
            par: Configuration dict with the keys ``char_file``, ``mode``,
                ``nc``, ``imgH``, ``imgW``, ``hidden``, ``mean``, ``std`` and
                ``dynamic``. When omitted, a fresh default configuration is
                used. A caller-supplied dict is kept by reference and gains
                the keys ``modelSize`` and ``modelType``.

        Raises:
            SystemExit: With status 1 when ``weights`` has no supported
                extension (or is ``None``).
            RuntimeError: When the TensorRT engine cannot be deserialized.
        """
        # Classify the weight file first so an unsupported file fails before
        # any file is read or any model is built.
        infer_type = self._resolve_infer_type(weights)
        if infer_type is None:
            print('#########ERROR:',weights,': no registered inference type, exit')
            sys.exit(1)  # non-zero: this is an error exit

        if par is None:
            par = self._default_par()
        self.par = par
        self.device = 'cuda:0'
        self.half = True  # not read anywhere in this class
        self.dynamic = par['dynamic']
        # Stored as (width, height), the order cv2.resize expects for dsize.
        self.par['modelSize'] = (par['imgW'], par['imgH'])

        # Explicit encoding so the alphabet (and therefore nclass) does not
        # depend on the process locale.
        with open(par['char_file'], 'r', encoding='utf-8') as fp:
            alphabet = fp.read()
        self.converter = strLabelConverter(alphabet)
        self.nclass = len(alphabet) + 1
        self.infer_type = infer_type

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Deserialize the local TensorRT file into an ICudaEngine object.
                self.model = runtime.deserialize_cuda_engine(f.read())
            if self.model is None:
                # Surface the failure here instead of at the first inference.
                raise RuntimeError('failed to deserialize TensorRT engine: %s' % (weights,))
        else:
            # The two CRNN implementations take their arguments in different orders.
            if par['mode'] == 'ch':
                import ocrUtils2.crnnCh as crnn
                self.model = crnn.CRNN(par['nc'], par['hidden'], self.nclass, par['imgH'])
            else:
                import ocrUtils2.crnn_model as crnn
                self.model = crnn.CRNN(par['imgH'], par['nc'], self.nclass, par['hidden'])
            self.load_model_weights(weights)
            self.model = self.model.to(self.device)
            print('#######load pt model:%s success '%(weights))

        self.par['modelType'] = self.infer_type
        print('#########加载模型:',weights,' 类型:',self.infer_type)

    def eval(self, image):
        """Recognise the text in one image.

        Args:
            image: Input image as accepted by ``preprocess_image``.

        Returns:
            tuple: ``(res_real, timeInfos)`` where ``res_real`` is whatever
            ``self.converter.decode`` returns and ``timeInfos`` is a string
            with total / pre-process / inference / post-process times in ms.
        """
        t0 = time.time()
        image = self.preprocess_image(image)
        t1 = time.time()
        if self.infer_type == 'pth':
            self.model.eval()
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                preds = self.model(image)
        else:
            preds, _ = OcrTrtForward(self.model, [image], False)
        t2 = time.time()

        # NOTE(review): the indexing below treats preds as
        # (sequence, batch, classes) with batch == 1 -- confirm against the models.
        preds_size = torch.IntTensor([preds.size(0)] * 1)
        _, preds = F.softmax(preds, dim=2).max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        res_real = self.converter.decode(preds, preds_size, raw=False)
        t3 = time.time()

        timeInfos = 'total:%.1f (preProcess:%.1f ,inference:%.1f, postProcess:%.1f) '%(
            self.get_ms(t3, t0), self.get_ms(t1, t0), self.get_ms(t2, t1), self.get_ms(t3, t2))
        return res_real, timeInfos

    def preprocess_image(self, image):
        """Convert an image into a normalised float tensor with a leading batch axis.

        Steps: colour conversion (BGR->gray when ``nc == 1``, otherwise
        BGR->RGB), resize, optional ``np_resize_keepRation`` for the TensorRT
        backend, scaling to [0, 1], mean/std normalisation, channel-first
        layout, batch axis, and transfer to ``self.device``.

        Args:
            image: Image array with the channel axis last; treated as BGR.

        Returns:
            torch.Tensor: Float tensor with a batch axis of size 1 in front.
        """
        model_w, model_h = self.par['modelSize']
        single_channel = self.par['nc'] == 1
        mean, std = self.par['mean'], self.par['std']

        if single_channel:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image = image[:, :, ::-1]  # BGR -> RGB

        if self.dynamic:
            # Scale to the model height; the same factor on both axes keeps
            # the aspect ratio, so the width varies.
            scale = model_h / image.shape[0]
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        else:
            image = cv2.resize(image, (model_w, model_h), interpolation=cv2.INTER_LINEAR)

        if self.infer_type == 'trt':
            image = np_resize_keepRation(image, model_h, model_w)

        image = image.astype(np.float32)
        image /= 255.0
        if single_channel:
            image = (image - mean[0]) / std[0]
            image = np.expand_dims(image, 0)
        else:
            # Normalise the first three channels, each with its own mean/std.
            image[:, :, :3] -= np.asarray(mean[:3], dtype=np.float32)
            image[:, :, :3] /= np.asarray(std[:3], dtype=np.float32)
            image = np.transpose(image, (2, 0, 1))

        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        if self.device != 'cpu':
            image = image.to(self.device)
        return image

    def get_ms(self, t1, t0):
        """Return the elapsed time from ``t0`` to ``t1`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def load_model_weights(self, weight):
        """Load a checkpoint file into ``self.model``.

        Accepts either a bare state dict or a dict holding one under the
        ``'state_dict'`` key. If the first load fails and some keys carry the
        ``module.`` prefix, that prefix is stripped and the load is retried.

        Args:
            weight: Path to the checkpoint file.

        Raises:
            RuntimeError: When the state dict does not fit the model even
                after stripping the ``module.`` prefix.
        """
        # Load onto the CPU: the caller moves the model to its device
        # afterwards, so the device the checkpoint was saved from is irrelevant.
        checkpoint = torch.load(weight, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError:
            # Fix up the parameter names: drop the leading `module.` prefix.
            from collections import OrderedDict
            prefix = 'module.'
            if not any(k.startswith(prefix) for k in state_dict):
                raise  # nothing to fix; keep the original error
            renamed = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items()
            )
            self.model.load_state_dict(renamed)
if __name__== "__main__":
    # Demo: run the recogniser over every .jpg in a directory and print, per
    # image, the filtered text, the file name and the timing summary.
    # Alternative PyTorch checkpoint for the same configuration:
    #weights = '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/weights/scene_base.pth'
    weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_ch_2080Ti_fp16_192X32.engine'
    inputDir = '/home/thsw2/WJ/src/OCR/shipNames'
    par = {
        #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
        'char_file': '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/src/models/CRNN/data/benchmark.txt',
        'mode': 'ch',
        'nc': 3,
        'imgH': 32,
        'imgW': 192,
        'hidden': 256,
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'dynamic': False,
    }
    # Disabled alternate (English model) configuration, kept as an inert string.
    '''
weights = '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/crnn_448X32.pth'
#weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_en_2080Ti_fp16_448X32.engine'
par={
#'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
'char_file':'/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/chars2.txt',
'mode':'en',
'nc':1,
'imgH':32,
'imgW':448,
'hidden':256,
'mean':[0.588,0.588,0.588],
'std':[0.193,0.193,0.193 ],
'dynamic':True
}
inputDir='/home/thsw2/WJ/src/DSP2/AIdemo2/images/ocr_en'
'''
    recognizer = ocrModel(weights=weights, par=par)
    for image_path in glob.glob('%s/*.jpg'%(inputDir)):
        frame = cv2.imread(image_path)
        text, timings = recognizer.eval(frame)
        # Keep only code points strictly between 19968 and 63865, plus the
        # ASCII digits '0'-'9' (code points 48..57).
        # NOTE(review): the strict lower bound excludes U+4E00 (19968) itself --
        # confirm whether that is intended.
        kept = "".join(
            ch for ch in text
            if 19968 < ord(ch) < 63865 or 47 < ord(ch) < 58
        )
        print(kept, os.path.basename(image_path), timings)