You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

306 lines
11KB

  1. import torch
  2. import numpy as np
  3. import torchvision.transforms as transforms
  4. import math, yaml
  5. from easydict import EasyDict as edict
  6. from PIL import Image
  7. import cv2
  8. from torch.autograd import Variable
  9. import time
  10. import tensorrt as trt
  11. def trt_version():
  12. return trt.__version__
  13. def torch_device_from_trt(device):
  14. if device == trt.TensorLocation.DEVICE:
  15. return torch.device("cuda")
  16. elif device == trt.TensorLocation.HOST:
  17. return torch.device("cpu")
  18. else:
  19. return TypeError("%s is not supported by torch" % device)
  20. def torch_dtype_from_trt(dtype):
  21. if dtype == trt.int8:
  22. return torch.int8
  23. elif trt_version() >= '7.0' and dtype == trt.bool:
  24. return torch.bool
  25. elif dtype == trt.int32:
  26. return torch.int32
  27. elif dtype == trt.float16:
  28. return torch.float16
  29. elif dtype == trt.float32:
  30. return torch.float32
  31. else:
  32. raise TypeError("%s is not supported by torch" % dtype)
  33. def OcrTrtForward(engine,inputs,contextFlag=False):
  34. t0=time.time()
  35. #with engine.create_execution_context() as context:
  36. if not contextFlag: context = engine.create_execution_context()
  37. else: context=contextFlag
  38. namess=[ engine.get_tensor_name(index) for index in range(engine.num_bindings) ]
  39. input_names = [namess[0]];output_names=namess[1:]
  40. batch_size = inputs[0].shape[0]
  41. bindings = [None] * (len(input_names) + len(output_names))
  42. t1=time.time()
  43. # 创建输出tensor,并分配内存
  44. outputs = [None] * len(output_names)
  45. for i, output_name in enumerate(output_names):
  46. idx = engine.get_binding_index(output_name)#通过binding_name找到对应的input_id
  47. dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))#找到对应的数据类型
  48. shape = (batch_size,) + tuple(engine.get_binding_shape(idx))#找到对应的形状大小
  49. device = torch_device_from_trt(engine.get_location(idx))
  50. output = torch.empty(size=shape, dtype=dtype, device=device)
  51. #print('&'*10,'device:',device,'idx:',idx,'shape:',shape,'dtype:',dtype,' device:',output.get_device())
  52. outputs[i] = output
  53. #print('###line65:',output_name,i,idx,dtype,shape)
  54. bindings[idx] = output.data_ptr()#绑定输出数据指针
  55. t2=time.time()
  56. for i, input_name in enumerate(input_names):
  57. idx =engine.get_binding_index(input_name)
  58. bindings[idx] = inputs[0].contiguous().data_ptr()#应当为inputs[i],对应3个输入。但由于我们使用的是单张图片,所以将3个输入全设置为相同的图片。
  59. #print('#'*10,'input_names:,', input_name,'idx:',idx, inputs[0].dtype,', inputs[0] device:',inputs[0].get_device())
  60. t3=time.time()
  61. context.execute_v2(bindings) # 执行推理
  62. t4=time.time()
  63. if len(outputs) == 1:
  64. outputs = outputs[0]
  65. outstr='create Context:%.2f alloc memory:%.2f prepare input:%.2f conext infer:%.2f, total:%.2f'%((t1-t0 )*1000 , (t2-t1)*1000,(t3-t2)*1000,(t4-t3)*1000, (t4-t0)*1000 )
  66. return outputs[0],outstr
  67. def np_resize_keepRation(img,inp_h, inp_w):
  68. img_h, img_w = img.shape
  69. fy=inp_h/img_h
  70. keep_w = int(img_w* fy )
  71. Rsize=( keep_w , img_h)
  72. img = cv2.resize(img, Rsize )
  73. #resize后是120,max是160,120-160的地方用边界的值填充
  74. if keep_w < inp_w:
  75. img_out = np.zeros((inp_h, inp_w ),dtype=np.uint8)
  76. img_out[:,:keep_w]=img[:,:]
  77. img_out[:,keep_w:] = np.tile(img[:,keep_w-1:], inp_w-keep_w)
  78. else:
  79. img_out = cv2.resize(img,(inp_w,inp_h))
  80. return img_out
  81. def recognition_ocr(config, img, model, converter, device,par={}):
  82. model_mode=par['model_mode'];contextFlag=par['contextFlag']
  83. if len(img.shape)==3:
  84. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  85. # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
  86. h, w = img.shape
  87. # fisrt step: resize the height and width of image to (32, x)
  88. img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)
  89. if model_mode=='trt':
  90. img = np_resize_keepRation(img,par['imgH'], par['imgW'])
  91. img = np.expand_dims(img,axis=2)
  92. # normalize
  93. img = img.astype(np.float32)
  94. img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
  95. img = img.transpose([2, 0, 1])
  96. img = torch.from_numpy(img)
  97. img = img.to(device)
  98. img = img.view(1, *img.size())
  99. if model_mode=='trt':
  100. img_input = img.to('cuda:0')
  101. time2 = time.time()
  102. preds,trtstr=OcrTrtForward(model,[img],contextFlag)
  103. else:
  104. model.eval()
  105. preds = model(img)
  106. _, preds = preds.max(2)
  107. preds = preds.transpose(1, 0).contiguous().view(-1)
  108. preds_size = Variable(torch.IntTensor([preds.size(0)]))
  109. sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
  110. return sim_pred
  111. class strLabelConverter(object):
  112. """Convert between str and label.
  113. NOTE:
  114. Insert `blank` to the alphabet for CTC.
  115. Args:
  116. alphabet (str): set of the possible characters.
  117. ignore_case (bool, default=True): whether or not to ignore all of the case.
  118. """
  119. def __init__(self, alphabet, ignore_case=False):
  120. self._ignore_case = ignore_case
  121. if self._ignore_case:
  122. alphabet = alphabet.lower()
  123. self.alphabet = alphabet + '-' # for `-1` index
  124. self.dict = {}
  125. for i, char in enumerate(alphabet):
  126. # NOTE: 0 is reserved for 'blank' required by wrap_ctc
  127. self.dict[char] = i + 1
  128. def encode(self, text):
  129. """Support batch or single str.
  130. Args:
  131. text (str or list of str): texts to convert.
  132. Returns:
  133. torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
  134. torch.IntTensor [n]: length of each text.
  135. """
  136. length = []
  137. result = []
  138. decode_flag = True if type(text[0])==bytes else False
  139. for item in text:
  140. if decode_flag:
  141. item = item.decode('utf-8','strict')
  142. length.append(len(item))
  143. for char in item:
  144. index = self.dict[char]
  145. result.append(index)
  146. text = result
  147. return (torch.IntTensor(text), torch.IntTensor(length))
  148. def decode(self, t, length, raw=False):
  149. """Decode encoded texts back into strs.
  150. Args:
  151. torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
  152. torch.IntTensor [n]: length of each text.
  153. Raises:
  154. AssertionError: when the texts and its length does not match.
  155. Returns:
  156. text (str or list of str): texts to convert.
  157. """
  158. if length.numel() == 1:
  159. length = length[0]
  160. assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
  161. if raw:
  162. return ''.join([self.alphabet[i - 1] for i in t])
  163. else:
  164. char_list = []
  165. for i in range(length):
  166. if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
  167. char_list.append(self.alphabet[t[i] - 1])
  168. return ''.join(char_list)
  169. else:
  170. # batch mode
  171. assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
  172. texts = []
  173. index = 0
  174. for i in range(length.numel()):
  175. l = length[i]
  176. texts.append(
  177. self.decode(
  178. t[index:index + l], torch.IntTensor([l]), raw=raw))
  179. index += l
  180. return texts
  181. def get_alphabets(txtfile ):
  182. print(txtfile)
  183. with open(txtfile,'r') as fp:
  184. lines=fp.readlines()
  185. alphas=[x.strip() for x in lines]
  186. return "".join(alphas)
  187. def get_cfg(cfg,char_file):
  188. with open(cfg, 'r') as f:
  189. #config = yaml.load(f)
  190. config = yaml.load(f, Loader=yaml.FullLoader)
  191. config = edict(config)
  192. config.DATASET.ALPHABETS = get_alphabets(char_file.strip() )
  193. config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
  194. return config
  195. def custom_mean(x):
  196. return x.prod()**(2.0/np.sqrt(len(x)))
  197. def contrast_grey(img):
  198. high = np.percentile(img, 90)
  199. low = np.percentile(img, 10)
  200. return (high-low)/np.maximum(10, high+low), high, low
  201. def adjust_contrast_grey(img, target = 0.4):
  202. contrast, high, low = contrast_grey(img)
  203. if contrast < target:
  204. img = img.astype(int)
  205. ratio = 200./np.maximum(10, high-low)
  206. img = (img - low + 25)*ratio
  207. img = np.maximum(np.full(img.shape, 0) ,np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
  208. return img
  209. class NormalizePAD(object):
  210. def __init__(self, max_size, PAD_type='right'):
  211. self.toTensor = transforms.ToTensor()
  212. self.max_size = max_size
  213. self.max_width_half = math.floor(max_size[2] / 2)
  214. self.PAD_type = PAD_type
  215. def __call__(self, img):
  216. img = self.toTensor(img)
  217. img.sub_(0.5).div_(0.5)
  218. c, h, w = img.size()
  219. Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
  220. Pad_img[:, :, :w] = img # right pad
  221. if self.max_size[2] != w: # add border Pad
  222. Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
  223. return Pad_img
  224. class AlignCollate(object):
  225. def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast = 0.):
  226. self.imgH = imgH
  227. self.imgW = imgW
  228. self.keep_ratio_with_pad = keep_ratio_with_pad
  229. self.adjust_contrast = adjust_contrast
  230. def __call__(self, batch):
  231. #print('##recongnition.py line72: type(batch[0]):',type(batch[0]),batch[0], )
  232. batch = filter(lambda x: x is not None, batch)
  233. images = batch
  234. resized_max_w = self.imgW
  235. input_channel = 1
  236. transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
  237. resized_images = []
  238. for image in images:
  239. w, h = image.size
  240. #### augmentation here - change contrast
  241. if self.adjust_contrast > 0:
  242. image = np.array(image.convert("L"))
  243. image = adjust_contrast_grey(image, target = self.adjust_contrast)
  244. image = Image.fromarray(image, 'L')
  245. ratio = w / float(h)
  246. if math.ceil(self.imgH * ratio) > self.imgW:
  247. resized_w = self.imgW
  248. else:
  249. resized_w = math.ceil(self.imgH * ratio)
  250. resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
  251. resized_images.append(transform(resized_image))
  252. image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
  253. return image_tensors