import sys, os
import glob, time
from collections import OrderedDict

import cv2
import numpy as np
import tensorrt as trt
import torch
import torch.nn.functional as F

from ocrUtils2.ocrUtils import strLabelConverter, OcrTrtForward, np_resize_keepRation
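
# CRNN-based OCR wrapper. The weights extension selects the backend: a
# serialized TensorRT engine (.engine) or a PyTorch CRNN checkpoint
# (.pth/.pt). Text is produced by greedy CTC decoding via strLabelConverter.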
class ocrModel(object):
    def __init__(self, weights=None, par=None):
        # Avoid a shared mutable default argument: build the config per call.
        if par is None:
            par = {
                # 'cfg': '../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
                'char_file': '../AIlib2/weights/conf/OCR_Ch/Ch.txt',
                'mode': 'ch',
                'nc': 3,
                'imgH': 32,
                'imgW': 256,
                'hidden': 256,
                'mean': [0.5, 0.5, 0.5],
                'std': [0.5, 0.5, 0.5],
                'dynamic': False,
            }
        self.par = par
        self.device = 'cuda:0'
        self.half = True
        self.dynamic = par['dynamic']
        self.par['modelSize'] = (par['imgW'], par['imgH'])
        with open(par['char_file'], 'r') as fp:
            alphabet = fp.read()
        self.converter = strLabelConverter(alphabet)
        self.nclass = len(alphabet) + 1  # +1 for the CTC blank label

        # Pick the inference backend from the weights file extension.
        if weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights.endswith('.pth') or weights.endswith('.pt'):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            sys.exit(1)

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Deserialize the local TensorRT engine file into an ICudaEngine.
                self.model = runtime.deserialize_cuda_engine(f.read())
        elif self.infer_type == 'pth':
            if par['mode'] == 'ch':
                import ocrUtils2.crnnCh as crnn
                self.model = crnn.CRNN(par['nc'], par['hidden'], self.nclass, par['imgH'])
            else:
                import ocrUtils2.crnn_model as crnn
                self.model = crnn.CRNN(par['imgH'], par['nc'], self.nclass, par['hidden'])
            self.load_model_weights(weights)
            self.model = self.model.to(self.device)
            print('#######load pt model:%s success' % weights)
        self.par['modelType'] = self.infer_type
        print('#########loaded model:', weights, ' type:', self.infer_type)
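
    # Run one BGR image through the network and greedy-decode the CTC
    # output: softmax over the class axis, argmax per time step, then
    # strLabelConverter.decode collapses repeats and drops blanks.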
    def eval(self, image):
        t0 = time.time()
        image = self.preprocess_image(image)
        t1 = time.time()
        if self.infer_type == 'pth':
            self.model.eval()
            with torch.no_grad():  # inference only, no autograd graph
                preds = self.model(image)
        else:
            preds, trtstr = OcrTrtForward(self.model, [image], False)
        t2 = time.time()
        preds_size = torch.IntTensor([preds.size(0)])  # sequence length for the batch of 1
        preds = F.softmax(preds, dim=2)
        preds_score, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        res_real = self.converter.decode(preds, preds_size, raw=False)
        t3 = time.time()
        timeInfos = 'total:%.1f (preProcess:%.1f, inference:%.1f, postProcess:%.1f)' % (
            self.get_ms(t3, t0), self.get_ms(t1, t0), self.get_ms(t2, t1), self.get_ms(t3, t2))
        return res_real, timeInfos
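
    # Mirror the training-time preprocessing: grayscale (nc==1) or BGR->RGB,
    # resize to modelSize (aspect-preserving when dynamic=True), scale to
    # [0,1], normalize by per-channel mean/std, and return a (1, nc, H, W)
    # float tensor on the target device.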
    def preprocess_image(self, image):
        if self.par['nc'] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image = image[:, :, ::-1]  # BGR -> RGB
        if self.dynamic:
            # Keep the aspect ratio: scale both axes so the height becomes imgH.
            H, W = image.shape[0:2]
            scale = self.par['modelSize'][1] / H
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        else:
            image = cv2.resize(image, self.par['modelSize'], interpolation=cv2.INTER_LINEAR)
        if self.infer_type == 'trt':
            # TRT engines expect a fixed input shape; np_resize_keepRation
            # restores (imgH, imgW) while keeping the aspect ratio.
            image = np_resize_keepRation(image, self.par['modelSize'][1], self.par['modelSize'][0])
        image = image.astype(np.float32)
        image /= 255.0
        if self.par['nc'] == 1:
            image = (image - self.par['mean'][0]) / self.par['std'][0]
            image = np.expand_dims(image, 0)  # (1, H, W)
        else:
            image -= np.array(self.par['mean'], dtype=np.float32)
            image /= np.array(self.par['std'], dtype=np.float32)
            image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)  # add the batch dimension
        if self.device != 'cpu':
            image = image.to(self.device)
        return image
    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0
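
    # Checkpoints saved from nn.DataParallel prefix every key with
    # 'module.'; if a direct load fails, retry with the prefix stripped.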
    def load_model_weights(self, weight):
        checkpoint = torch.load(weight)
        if 'state_dict' in checkpoint.keys():
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            try:
                self.model.load_state_dict(checkpoint)
            except RuntimeError:
                # Fix the parameter names: build a new OrderedDict with the
                # leading 'module.' removed, then load the params.
                new_state_dict = OrderedDict()
                for k, v in checkpoint.items():
                    name = k[7:]  # remove 'module.'
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)
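
# Self-test: recognize every .jpg crop in inputDir and print the text,
# keeping only CJK ideographs and ASCII digits.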
if __name__ == "__main__":
    # Chinese model (TensorRT engine).
    #weights = '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/weights/scene_base.pth'
    weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_ch_2080Ti_fp16_192X32.engine'
    par = {
        # 'cfg': '../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
        'char_file': '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/src/models/CRNN/data/benchmark.txt',
        'mode': 'ch',
        'nc': 3,
        'imgH': 32,
        'imgW': 192,
        'hidden': 256,
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'dynamic': False,
    }
    inputDir = '/home/thsw2/WJ/src/OCR/shipNames'
    '''
    # English model (PyTorch checkpoint), kept for reference.
    weights = '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/crnn_448X32.pth'
    #weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_en_2080Ti_fp16_448X32.engine'
    par = {
        # 'cfg': '../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
        'char_file': '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/chars2.txt',
        'mode': 'en',
        'nc': 1,
        'imgH': 32,
        'imgW': 448,
        'hidden': 256,
        'mean': [0.588, 0.588, 0.588],
        'std': [0.193, 0.193, 0.193],
        'dynamic': True,
    }
    inputDir = '/home/thsw2/WJ/src/DSP2/AIdemo2/images/ocr_en'
    '''
    model = ocrModel(weights=weights, par=par)
    imgUrls = glob.glob('%s/*.jpg' % inputDir)
    for imgUrl in imgUrls:
        img = cv2.imread(imgUrl)
        res_real, timeInfos = model.eval(img)
        # Keep only CJK ideographs (from U+4E00) and the ASCII digits 0-9.
        res_real = "".join(filter(
            lambda x: (19968 < ord(x) < 63865) or (47 < ord(x) < 58), res_real))
        print(res_real, os.path.basename(imgUrl), timeInfos)
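
# ------------------------------------------------------------------
# Sketch (not part of the original file): one way an engine such as
# crnn_ch_2080Ti_fp16_192X32.engine could be built from the .pth
# checkpoint. The output paths and opset version are assumptions.
#
#   m = ocrModel(weights='crnn.pth', par=par).model.eval()
#   dummy = torch.randn(1, par['nc'], par['imgH'], par['imgW'], device='cuda:0')
#   torch.onnx.export(m, dummy, 'crnn.onnx',
#                     input_names=['images'], output_names=['preds'],
#                     opset_version=12)
#   # then, with the TensorRT CLI:
#   #   trtexec --onnx=crnn.onnx --saveEngine=crnn.engine --fp16
# ------------------------------------------------------------------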