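"""CRNN OCR model wrapper.

Runs a CRNN text recognizer behind a single interface, selecting the
inference backend from the weights-file extension: a PyTorch checkpoint
(.pth/.pt) or a serialized TensorRT engine (.engine). The pipeline is
preprocess (resize + normalize) -> forward pass -> greedy CTC-style decode
via strLabelConverter.
"""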
import tensorrt as trt
import sys, os
import cv2, glob, time
import torch
import utils  # only needed if the commented-out utils.strLabelConverter path is restored
import numpy as np
from ocrUtils2.ocrUtils import strLabelConverter, OcrTrtForward, np_resize_keepRation


class ocrModel(object):
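    """OCR model wrapper.

    Chooses the inference backend from the weights-file extension:
    '.engine' -> TensorRT, '.pth'/'.pt' -> PyTorch CRNN (Chinese or English
    variant depending on par['mode']).
    """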
    def __init__(self, weights=None, par=None):
        # Build the default config inside the method: self.par is mutated below
        # (modelSize/modelType), so a mutable default argument would leak state
        # across instances.
        if par is None:
            par = {
                #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
                'char_file': '../AIlib2/weights/conf/OCR_Ch/Ch.txt',
                'mode': 'ch',        # 'ch' selects the Chinese CRNN, anything else the English one
                'nc': 3,             # number of input channels
                'imgH': 32,          # model input height
                'imgW': 256,         # model input width
                'hidden': 256,       # hidden size of the recurrent layers
                'mean': [0.5, 0.5, 0.5],
                'std': [0.5, 0.5, 0.5],
                'dynamic': False,    # True: scale by height and keep the aspect ratio
            }
        self.par = par
        self.device = 'cuda:0'
        self.half = True
        self.dynamic = par['dynamic']
        self.par['modelSize'] = (par['imgW'], par['imgH'])
        with open(par['char_file'], 'r') as fp:
            alphabet = fp.read()
        #self.converter = utils.strLabelConverter(alphabet)
        self.converter = strLabelConverter(alphabet)
        self.nclass = len(alphabet) + 1  # +1 for the CTC blank label
        if weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights.endswith('.pth') or weights.endswith('.pt'):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            sys.exit(1)  # non-zero exit status on error
        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # deserialize the local TensorRT engine file into an ICudaEngine
                self.model = runtime.deserialize_cuda_engine(f.read())
            self.context = self.model.create_execution_context()
        elif self.infer_type == 'pth':
            if par['mode'] == 'ch':
                import ocrUtils2.crnnCh as crnn
                self.model = crnn.CRNN(par['nc'], par['hidden'], self.nclass, par['imgH'])
            else:
                import ocrUtils2.crnn_model as crnn
                self.model = crnn.CRNN(par['imgH'], par['nc'], self.nclass, par['hidden'])
            self.load_model_weights(weights)
            self.model = self.model.to(self.device)
            print('#######loaded pth model %s successfully' % weights)
        self.par['modelType'] = self.infer_type
        print('#########loaded model:', weights, ' type:', self.infer_type)
    def eval(self, image):
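        """Run OCR on a single BGR image; return (decoded_text, timing_info)."""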
        t0 = time.time()
        image = self.preprocess_image(image)
        t1 = time.time()
        if self.infer_type == 'pth':
            self.model.eval()
            preds = self.model(image)
        else:
            preds, trtstr = OcrTrtForward(self.model, [image], False)
            #preds, trtstr = OcrTrtForward(self.model, [image], self.context)
        t2 = time.time()
        preds_size = torch.IntTensor([preds.size(0)] * 1)  # time steps per batch item (batch size 1)
        _, preds = preds.max(2)  # greedy decode: most likely class at each time step
        preds = preds.transpose(1, 0).contiguous().view(-1)
        res_real = self.converter.decode(preds, preds_size, raw=False)
        t3 = time.time()
        timeInfos = 'total:%.1f (preProcess:%.1f, inference:%.1f, postProcess:%.1f)' % (
            self.get_ms(t3, t0), self.get_ms(t1, t0), self.get_ms(t2, t1), self.get_ms(t3, t2))
        return res_real, timeInfos
    def preprocess_image(self, image):
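        """Convert a BGR image to a normalized NCHW float tensor on self.device.

        With dynamic=True the image is scaled to the model height while keeping
        its aspect ratio; otherwise it is resized to the fixed modelSize. For
        the TensorRT backend the result is additionally brought to the fixed
        engine input size via np_resize_keepRation.
        """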
        if self.par['nc'] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image = image[:, :, ::-1]  # BGR -> RGB
        if self.dynamic:
            # scale to the model height, keeping the aspect ratio
            H, W = image.shape[0:2]
            scale = self.par['modelSize'][1] / H
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        else:
            re_size = self.par['modelSize']
            image = cv2.resize(image, re_size, interpolation=cv2.INTER_LINEAR)
        if self.infer_type == 'trt':
            # TensorRT engines expect a fixed input size: resize/pad while keeping the ratio
            image = np_resize_keepRation(image, self.par['modelSize'][1], self.par['modelSize'][0])
        image = image.astype(np.float32)
        image /= 255.0
        if self.par['nc'] == 1:
            image = (image - self.par['mean'][0]) / self.par['std'][0]
            image = np.expand_dims(image, 0)  # (H, W) -> (1, H, W)
        else:
            image[:, :, 0] -= self.par['mean'][0]
            image[:, :, 1] -= self.par['mean'][1]
            image[:, :, 2] -= self.par['mean'][2]
            image[:, :, 0] /= self.par['std'][0]
            image[:, :, 1] /= self.par['std'][1]
            image[:, :, 2] /= self.par['std'][2]
            image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)  # add the batch dimension
        if self.device != 'cpu':
            image = image.to(self.device)
        return image
    def get_ms(self, t1, t0):
        # elapsed time between two time.time() stamps, in milliseconds
        return (t1 - t0) * 1000.0
    def load_model_weights(self, weight):
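        """Load a checkpoint, tolerating both plain state dicts and
        nn.DataParallel checkpoints whose keys carry a 'module.' prefix."""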
        checkpoint = torch.load(weight)
        if 'state_dict' in checkpoint.keys():
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            try:
                self.model.load_state_dict(checkpoint)
            except Exception:
                # fix the parameter names: checkpoints saved from nn.DataParallel
                # prefix every key with 'module.'
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in checkpoint.items():
                    name = k[7:]  # strip the 'module.' prefix
                    new_state_dict[name] = v
                # load the renamed parameters
                self.model.load_state_dict(new_state_dict)


if __name__ == "__main__":
    #weights = '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/weights/scene_base.pth'
    weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_ch_2080Ti_fp16_192X32.engine'
    par = {
        #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
        'char_file': '/home/thsw2/WJ/src/OCR/benchmarking-chinese-text-recognition/src/models/CRNN/data/benchmark.txt',
        'mode': 'ch',
        'nc': 3,
        'imgH': 32,
        'imgW': 192,
        'hidden': 256,
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'dynamic': False,
    }
    inputDir = '/home/thsw2/WJ/src/OCR/shipNames'
    '''
    # English OCR configuration:
    weights = '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/crnn_448X32.pth'
    #weights = '/mnt/thsw2/DSP2/weights/ocr2/crnn_en_2080Ti_fp16_448X32.engine'
    par = {
        #'cfg':'../AIlib2/weights/conf/OCR_Ch/360CC_config.yaml',
        'char_file': '/home/thsw2/WJ/src/DSP2/AIlib2/weights/conf/ocr2/chars2.txt',
        'mode': 'en',
        'nc': 1,
        'imgH': 32,
        'imgW': 448,
        'hidden': 256,
        'mean': [0.588, 0.588, 0.588],
        'std': [0.193, 0.193, 0.193],
        'dynamic': True
    }
    inputDir = '/home/thsw2/WJ/src/DSP2/AIdemo2/images/ocr_en'
    '''
    model = ocrModel(weights=weights, par=par)
    imgUrls = glob.glob('%s/*.jpg' % inputDir)
    for imgUrl in imgUrls:
        img = cv2.imread(imgUrl)
        if img is None:  # skip files OpenCV cannot decode
            continue
        res_real, timeInfos = model.eval(img)
        # keep only CJK characters and ASCII digits in the decoded string
        res_real = "".join(filter(
            lambda x: (19968 < ord(x) < 63865) or (47 < ord(x) < 58), res_real))
        print(res_real, os.path.basename(imgUrl), timeInfos)