import glob
import os
import sys
import time
from collections import OrderedDict

import cv2
import numpy as np
import tensorrt as trt
import torch
import torch.nn.functional as F

from ocrUtils2.ocrUtils import strLabelConverter, OcrTrtForward, np_resize_keepRation


class ocrModel(object):
    def __init__(self, weights=None, par=None):
        # Default configuration (Chinese CRNN). A mutable dict must not be used
        # as a default argument, so it is built here instead.
        if par is None:
            par = {
                'char_file': '../AIlib2/weights/conf/OCR_Ch/Ch.txt',
                'mode': 'ch',
                'nc': 3,
                'imgH': 32,
                'imgW': 256,
                'hidden': 256,
                'mean': [0.5, 0.5, 0.5],
                'std': [0.5, 0.5, 0.5],
                'dynamic': False,
            }
        self.par = par
        self.device = 'cuda:0'
        self.half = True
        self.dynamic = par['dynamic']
        self.par['modelSize'] = (par['imgW'], par['imgH'])

        with open(par['char_file'], 'r', encoding='utf-8') as fp:
            alphabet = fp.read()
        self.converter = strLabelConverter(alphabet)
        self.nclass = len(alphabet) + 1  # +1 for the CTC blank symbol

        # Select the inference backend from the weight-file extension.
        if weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights.endswith('.pth') or weights.endswith('.pt'):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            sys.exit(1)

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Deserialize the local TensorRT engine file into an ICudaEngine object.
                self.model = runtime.deserialize_cuda_engine(f.read())
        elif self.infer_type == 'pth':
            if par['mode'] == 'ch':
                import ocrUtils2.crnnCh as crnn
                self.model = crnn.CRNN(par['nc'], par['hidden'], self.nclass, par['imgH'])
            else:
                import ocrUtils2.crnn_model as crnn
                self.model = crnn.CRNN(par['imgH'], par['nc'], self.nclass, par['hidden'])
            self.load_model_weights(weights)
            self.model = self.model.to(self.device)
            print('#######load pt model:%s success' % weights)

        self.par['modelType'] = self.infer_type
        print('######### loaded model:', weights, ' type:', self.infer_type)

    def eval(self, image):
        t0 = time.time()
        image = self.preprocess_image(image)
        t1 = time.time()
        if self.infer_type == 'pth':
            self.model.eval()
            with torch.no_grad():  # inference only, no gradients needed
                preds = self.model(image)
        else:
            preds, trtstr = OcrTrtForward(self.model, [image], False)
        t2 = time.time()
        # Greedy CTC decoding: softmax over the class dimension, take the argmax
        # per time step, then let the converter collapse repeats and remove
        # blanks (raw=False).
        preds_size = torch.IntTensor([preds.size(0)] * 1)  # sequence length, batch of 1
        preds = F.softmax(preds, dim=2)
        preds_score, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        res_real = self.converter.decode(preds, preds_size, raw=False)
        t3 = time.time()
        timeInfos = 'total:%.1f (preProcess:%.1f ,inference:%.1f, postProcess:%.1f) ' % (
            self.get_ms(t3, t0), self.get_ms(t1, t0), self.get_ms(t2, t1), self.get_ms(t3, t2))
        return res_real, timeInfos

    def preprocess_image(self, image):
        if self.par['nc'] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image = image[:, :, ::-1]  # BGR -> RGB
        if self.dynamic:
            # Scale to the model height while keeping the aspect ratio.
            H, W = image.shape[0:2]
            image = cv2.resize(image, (0, 0), fx=self.par['modelSize'][1] / H,
                               fy=self.par['modelSize'][1] / H, interpolation=cv2.INTER_CUBIC)
        else:
            re_size = self.par['modelSize']
            image = cv2.resize(image, re_size, interpolation=cv2.INTER_LINEAR)
        if self.infer_type == 'trt':
            # TensorRT engines expect a fixed input size; pad to it while
            # keeping the aspect ratio.
            image = np_resize_keepRation(image, self.par['modelSize'][1], self.par['modelSize'][0])

        image = image.astype(np.float32)
        image /= 255.0
        if self.par['nc'] == 1:
            image = (image - self.par['mean'][0]) / self.par['std'][0]
            image = np.expand_dims(image, 0)
        else:
            # Per-channel normalization, then HWC -> CHW.
            image[:, :, 0] -= self.par['mean'][0]
            image[:, :, 1] -= self.par['mean'][1]
            image[:, :, 2] -= self.par['mean'][2]
            image[:, :, 0] /= self.par['std'][0]
            image[:, :, 1] /= self.par['std'][1]
            image[:, :, 2] /= self.par['std'][2]
            image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)  # add the batch dimension
        if self.device != 'cpu':
            image = image.to(self.device)
        return image

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

    def load_model_weights(self, weight):
        checkpoint = torch.load(weight, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            try:
                self.model.load_state_dict(checkpoint)
            except RuntimeError:
                # Checkpoints saved from nn.DataParallel prefix every key with
                # 'module.'; strip the prefix and retry.
                new_state_dict = OrderedDict()
                for k, v in checkpoint.items():
                    name = k[7:]  # remove 'module.'
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)


if __name__ == "__main__":
    #weights = '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/weights/scene_base.pth'
    weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_ch_2080Ti_fp16_192X32.engine'
    par = {
        'char_file': '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/src/models/CRNN/data/benchmark.txt',
        'mode': 'ch',
        'nc': 3,
        'imgH': 32,
        'imgW': 192,
        'hidden': 256,
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'dynamic': False,
    }
    inputDir = '/home/thsw2/WJ/src/OCR/shipNames'
    '''
    # Alternative English-mode configuration:
    weights = '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/crnn_448X32.pth'
    #weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_en_2080Ti_fp16_448X32.engine'
    par = {
        'char_file': '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/chars2.txt',
        'mode': 'en',
        'nc': 1,
        'imgH': 32,
        'imgW': 448,
        'hidden': 256,
        'mean': [0.588, 0.588, 0.588],
        'std': [0.193, 0.193, 0.193],
        'dynamic': True,
    }
    inputDir = '/home/thsw2/WJ/src/DSP2/AIdemo2/images/ocr_en'
    '''
    model = ocrModel(weights=weights, par=par)
    imgUrls = glob.glob('%s/*.jpg' % inputDir)
    for imgUrl in imgUrls:
        img = cv2.imread(imgUrl)
        res_real, timeInfos = model.eval(img)
        # Keep only CJK ideographs (from U+4E00 upward) and ASCII digits 0-9.
        res_real = ''.join(filter(
            lambda x: (19968 < ord(x) < 63865) or (47 < ord(x) < 58), res_real))
        print(res_real, os.path.basename(imgUrl), timeInfos)