202 lines
7.6 KiB
Python
202 lines
7.6 KiB
Python
|
|
import tensorrt as trt
|
|||
|
|
import sys,os
|
|||
|
|
import cv2,glob,time
|
|||
|
|
import torch
|
|||
|
|
import utils
|
|||
|
|
import numpy as np
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from ocrUtils2.ocrUtils import strLabelConverter , OcrTrtForward,np_resize_keepRation
|
|||
|
|
|
|||
|
|
class ocrModel(object):
    """CRNN text recogniser that runs either a TensorRT engine or a PyTorch checkpoint.

    The backend is chosen from the file extension of ``weights``:
    ``.engine`` selects TensorRT, ``.pth`` / ``.pt`` selects PyTorch.
    """

    @staticmethod
    def _default_par():
        """Return a fresh copy of the default configuration.

        Built on every call so that instances never share (and never mutate)
        a single default dictionary.
        """
        return {
            #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
            'char_file': '../AIlib2/weights/conf/OCR_Ch/Ch.txt',
            'mode': 'ch',
            'nc': 3,
            'imgH': 32,
            'imgW': 256,
            'hidden': 256,
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'dynamic': False,
        }

    def __init__(self, weights=None, par=None):
        """Load the alphabet and the model.

        Args:
            weights: Path to the model file; must end with ``.engine``,
                ``.pth`` or ``.pt``.
            par: Configuration dictionary with the keys ``char_file``,
                ``mode``, ``nc``, ``imgH``, ``imgW``, ``hidden``, ``mean``,
                ``std`` and ``dynamic``. When omitted, a fresh copy of the
                built-in defaults is used. A dictionary supplied by the caller
                is kept by reference and receives the extra keys
                ``modelSize`` and ``modelType`` (same as before).

        Raises:
            SystemExit: If ``weights`` is missing or has an unsupported
                extension (exit status 1).
            RuntimeError: If the TensorRT engine cannot be deserialized.
        """
        if par is None:
            par = self._default_par()

        self.par = par
        self.device = 'cuda:0'
        self.half = True
        self.dynamic = par['dynamic']
        # cv2.resize takes its target size as (width, height).
        self.par['modelSize'] = (par['imgW'], par['imgH'])

        with open(par['char_file'], 'r') as fp:
            alphabet = fp.read()
        #self.converter = utils.strLabelConverter(alphabet)
        self.converter = strLabelConverter(alphabet)
        # One extra class on top of the alphabet (presumably the CTC blank -- TODO confirm).
        self.nclass = len(alphabet) + 1

        is_path = isinstance(weights, str)
        if is_path and weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif is_path and weights.endswith(('.pth', '.pt')):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:',weights,': no registered inference type, exit')
            # Non-zero status so that a supervising process can see the failure.
            sys.exit(1)

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Read the serialized engine file and turn it into an ICudaEngine.
                self.model = runtime.deserialize_cuda_engine(f.read())
            if self.model is None:
                raise RuntimeError('failed to deserialize TensorRT engine: %s' % (weights))
            #self.context = self.model.create_execution_context()

        elif self.infer_type == 'pth':
            # The two CRNN implementations take their constructor arguments in a different order.
            if par['mode'] == 'ch':
                import ocrUtils2.crnnCh as crnn
                self.model = crnn.CRNN(par['nc'], par['hidden'], self.nclass, par['imgH'])
            else:
                import ocrUtils2.crnn_model as crnn
                self.model = crnn.CRNN(par['imgH'], par['nc'], self.nclass, par['hidden'])

            self.load_model_weights(weights)
            self.model = self.model.to(self.device)

            print('#######load pt model:%s success '%(weights))

        self.par['modelType'] = self.infer_type
        print('#########加载模型:',weights,' 类型:',self.infer_type)

    def eval(self, image):
        """Recognise the text in one image.

        Args:
            image: Image array as accepted by ``preprocess_image``.

        Returns:
            A tuple ``(text, timeInfos)``: the decoded result returned by the
            label converter and a string with per-stage timings in
            milliseconds.
        """
        t0 = time.time()
        image = self.preprocess_image(image)
        t1 = time.time()

        if self.infer_type == 'pth':
            self.model.eval()
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                preds = self.model(image)
        else:
            preds, _ = OcrTrtForward(self.model, [image], False)

        t2 = time.time()
        # assumes preds is laid out as (sequence, batch, classes) -- TODO confirm
        preds_size = torch.IntTensor([preds.size(0)] * 1)
        preds = F.softmax(preds, dim=2)
        _, preds = preds.max(2)
        #print('##line78:',preds,preds_score)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        res_real = self.converter.decode(preds, preds_size, raw=False)
        t3 = time.time()

        timeInfos = 'total:%.1f (preProcess:%.1f ,inference:%.1f, postProcess:%.1f) '%( self.get_ms(t3,t0), self.get_ms(t1,t0), self.get_ms(t2,t1), self.get_ms(t3,t2), )

        return res_real, timeInfos

    def preprocess_image(self, image):
        """Convert a BGR image into a normalised tensor with a batch dimension.

        Steps: colour conversion (grey when ``nc`` is 1, otherwise BGR to
        RGB), resize, scale to [0, 1], subtract ``mean`` and divide by
        ``std``, reorder to channel-first, add a batch dimension and move the
        tensor to ``self.device``.

        Args:
            image: BGR image array (for example the result of ``cv2.imread``).

        Returns:
            A float tensor with a leading batch dimension of 1.
        """
        single_channel = self.par['nc'] == 1
        model_w, model_h = self.par['modelSize']

        if single_channel:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image = image[:, :, ::-1]  # bgr-->rgb

        if self.dynamic:
            # Scale both axes by the same factor so the height becomes imgH
            # and the aspect ratio is kept.
            H, W = image.shape[0:2]
            scale = model_h / H
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        else:
            image = cv2.resize(image, (model_w, model_h), interpolation=cv2.INTER_LINEAR)

        if self.infer_type == 'trt':
            # presumably fits the image to the fixed engine input size -- verify in ocrUtils
            image = np_resize_keepRation(image, model_h, model_w)

        image = image.astype(np.float32)
        image /= 255.0
        #print('####line105:',image.shape)

        if single_channel:
            image = (image - self.par['mean'][0]) / self.par['std'][0]
            image = np.expand_dims(image, 0)
        else:
            # Per-channel normalisation, broadcast over the last (channel) axis.
            mean = np.asarray(self.par['mean'][:3], dtype=np.float32)
            std = np.asarray(self.par['std'][:3], dtype=np.float32)
            image = (image - mean) / std
            image = np.transpose(image, (2, 0, 1))

        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        if self.device != 'cpu':
            image = image.to(self.device)

        return image

    def get_ms(self, t1, t0):
        """Return the elapsed time between ``t0`` and ``t1`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def load_model_weights(self, weight):
        """Load a checkpoint file into ``self.model``.

        Accepts either a raw state dict or a dict that wraps it under the
        ``'state_dict'`` key. If the first load attempt fails, a leading
        ``module.`` is removed from the keys that have it and the load is
        retried.

        Args:
            weight: Path to the checkpoint file.

        Raises:
            RuntimeError: If the parameters still do not match the model
                after the prefix has been stripped.
        """
        # Load onto the CPU first; the caller moves the model to its device afterwards.
        checkpoint = torch.load(weight, map_location='cpu')
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        try:
            self.model.load_state_dict(state_dict)
        except (RuntimeError, KeyError):
            # Fix up the parameter names: build a new OrderedDict without `module.`
            from collections import OrderedDict
            prefix = 'module.'
            renamed = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items()
            )
            # load params
            self.model.load_state_dict(renamed)
|
|||
|
|
|
|||
|
|
if __name__== "__main__":
    # Demo: run the recogniser over every *.jpg in inputDir and print the
    # filtered text, the file name and the timing string for each image.

    #weights = '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/weights/scene_base.pth'
    weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_ch_2080Ti_fp16_192X32.engine'
    par={
        #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
        'char_file':'/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/src/models/CRNN/data/benchmark.txt',
        'mode':'ch',
        'nc':3,
        'imgH':32,
        'imgW':192,
        'hidden':256,
        'mean':[0.5,0.5,0.5],
        'std':[0.5,0.5,0.5],
        'dynamic':False
    }
    inputDir = '/home/thsw2/WJ/src/OCR/shipNames'

    # Alternative configuration (mode 'en', single channel, dynamic width),
    # kept for reference:
    #
    # weights = '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/crnn_448X32.pth'
    # #weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_en_2080Ti_fp16_448X32.engine'
    # par={
    #     #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
    #     'char_file':'/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/chars2.txt',
    #     'mode':'en',
    #     'nc':1,
    #     'imgH':32,
    #     'imgW':448,
    #     'hidden':256,
    #     'mean':[0.588,0.588,0.588],
    #     'std':[0.193,0.193,0.193 ],
    #     'dynamic':True
    # }
    # inputDir='/home/thsw2/WJ/src/DSP2/AIdemo2/images/ocr_en'

    model = ocrModel(weights=weights,par=par )

    imgUrls = glob.glob('%s/*.jpg'%(inputDir))

    for imgUrl in imgUrls:
        img = cv2.imread(imgUrl)
        if img is None:
            # cv2.imread returns None instead of raising when a file cannot be decoded.
            print('#########WARNING: cannot read image, skipped:', imgUrl)
            continue

        res_real,timeInfos = model.eval(img)
        # Keep only characters whose code point lies strictly between 19968
        # and 63865, plus the ASCII digits 0-9 (code points 48..57).
        # NOTE(review): 19968 is U+4E00, so that character itself is dropped
        # by the strict comparison -- confirm whether this is intended.
        res_real = "".join(
            ch for ch in res_real
            if 19968 < ord(ch) < 63865 or 47 < ord(ch) < 58
        )
        print(res_real,os.path.basename(imgUrl),timeInfos )
|
|||
|
|
|
|||
|
|
|
|||
|
|
|