You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

315 lines
11KB

  1. import torch
  2. import numpy as np
  3. import torchvision.transforms as transforms
  4. import math, yaml
  5. from easydict import EasyDict as edict
  6. from PIL import Image
  7. import cv2
  8. from torch.autograd import Variable
  9. import time
  10. import tensorrt as trt
  11. def trt_version():
  12. return trt.__version__
  13. def torch_device_from_trt(device):
  14. if device == trt.TensorLocation.DEVICE:
  15. return torch.device("cuda")
  16. elif device == trt.TensorLocation.HOST:
  17. return torch.device("cpu")
  18. else:
  19. return TypeError("%s is not supported by torch" % device)
  20. def torch_dtype_from_trt(dtype):
  21. if dtype == trt.int8:
  22. return torch.int8
  23. elif trt_version() >= '7.0' and dtype == trt.bool:
  24. return torch.bool
  25. elif dtype == trt.int32:
  26. return torch.int32
  27. elif dtype == trt.float16:
  28. return torch.float16
  29. elif dtype == trt.float32:
  30. return torch.float32
  31. else:
  32. raise TypeError("%s is not supported by torch" % dtype)
  33. def OcrTrtForward(engine,inputs,contextFlag=False):
  34. t0=time.time()
  35. #with engine.create_execution_context() as context:
  36. if not contextFlag: context = engine.create_execution_context()
  37. else: context=contextFlag
  38. namess=[ engine.get_tensor_name(index) for index in range(engine.num_bindings) ]
  39. input_names = [namess[0]];output_names=namess[1:]
  40. batch_size = inputs[0].shape[0]
  41. bindings = [None] * (len(input_names) + len(output_names))
  42. t1=time.time()
  43. # 创建输出tensor,并分配内存
  44. outputs = [None] * len(output_names)
  45. for i, output_name in enumerate(output_names):
  46. idx = engine.get_binding_index(output_name)#通过binding_name找到对应的input_id
  47. dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))#找到对应的数据类型
  48. shape = (batch_size,) + tuple(engine.get_binding_shape(idx))#找到对应的形状大小
  49. device = torch_device_from_trt(engine.get_location(idx))
  50. output = torch.empty(size=shape, dtype=dtype, device=device)
  51. #print('&'*10,'device:',device,'idx:',idx,'shape:',shape,'dtype:',dtype,' device:',output.get_device())
  52. outputs[i] = output
  53. #print('###line65:',output_name,i,idx,dtype,shape)
  54. bindings[idx] = output.data_ptr()#绑定输出数据指针
  55. t2=time.time()
  56. for i, input_name in enumerate(input_names):
  57. idx =engine.get_binding_index(input_name)
  58. bindings[idx] = inputs[0].contiguous().data_ptr()#应当为inputs[i],对应3个输入。但由于我们使用的是单张图片,所以将3个输入全设置为相同的图片。
  59. #print('#'*10,'input_names:,', input_name,'idx:',idx, inputs[0].dtype,', inputs[0] device:',inputs[0].get_device())
  60. t3=time.time()
  61. context.execute_v2(bindings) # 执行推理
  62. t4=time.time()
  63. if len(outputs) == 1:
  64. outputs = outputs[0]
  65. outstr='create Context:%.2f alloc memory:%.2f prepare input:%.2f conext infer:%.2f, total:%.2f'%((t1-t0 )*1000 , (t2-t1)*1000,(t3-t2)*1000,(t4-t3)*1000, (t4-t0)*1000 )
  66. return outputs[0],outstr
  67. def np_resize_keepRation(img,inp_h, inp_w):
  68. #print(img.shape,inp_h,inp_w)
  69. img_h, img_w = img.shape[0:2]
  70. fy=inp_h/img_h
  71. keep_w = int(img_w* fy )
  72. Rsize=( keep_w , img_h)
  73. img = cv2.resize(img, Rsize )
  74. #resize后是120,max是160,120-160的地方用边界的值填充
  75. if keep_w < inp_w:
  76. if len(img.shape)==3:
  77. img_out = np.zeros((inp_h, inp_w,3 ),dtype=np.uint8)
  78. img_out[:,:keep_w]=img[:,:]
  79. for j in range(3):
  80. img_out[:,keep_w:,j] = np.tile(img[:,keep_w-1:,j], inp_w-keep_w)
  81. else:
  82. img_out = np.zeros((inp_h, inp_w ),dtype=np.uint8)
  83. img_out[:,:keep_w]=img[:,:]
  84. img_out[:,keep_w:] = np.tile(img[:,keep_w-1:], inp_w-keep_w)
  85. else:
  86. img_out = cv2.resize(img,(inp_w,inp_h))
  87. return img_out
  88. def recognition_ocr(config, img, model, converter, device,par={}):
  89. model_mode=par['model_mode'];contextFlag=par['contextFlag']
  90. if len(img.shape)==3:
  91. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  92. # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
  93. h, w = img.shape
  94. # fisrt step: resize the height and width of image to (32, x)
  95. img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)
  96. if model_mode=='trt':
  97. img = np_resize_keepRation(img,par['imgH'], par['imgW'])
  98. img = np.expand_dims(img,axis=2)
  99. # normalize
  100. img = img.astype(np.float32)
  101. img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
  102. img = img.transpose([2, 0, 1])
  103. img = torch.from_numpy(img)
  104. img = img.to(device)
  105. img = img.view(1, *img.size())
  106. if model_mode=='trt':
  107. img_input = img.to('cuda:0')
  108. time2 = time.time()
  109. preds,trtstr=OcrTrtForward(model,[img],contextFlag)
  110. else:
  111. model.eval()
  112. preds = model(img)
  113. _, preds = preds.max(2)
  114. preds = preds.transpose(1, 0).contiguous().view(-1)
  115. preds_size = Variable(torch.IntTensor([preds.size(0)]))
  116. sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
  117. return sim_pred
  118. class strLabelConverter(object):
  119. """Convert between str and label.
  120. NOTE:
  121. Insert `blank` to the alphabet for CTC.
  122. Args:
  123. alphabet (str): set of the possible characters.
  124. ignore_case (bool, default=True): whether or not to ignore all of the case.
  125. """
  126. def __init__(self, alphabet, ignore_case=False):
  127. self._ignore_case = ignore_case
  128. if self._ignore_case:
  129. alphabet = alphabet.lower()
  130. self.alphabet = alphabet + '-' # for `-1` index
  131. self.dict = {}
  132. for i, char in enumerate(alphabet):
  133. # NOTE: 0 is reserved for 'blank' required by wrap_ctc
  134. self.dict[char] = i + 1
  135. def encode(self, text):
  136. """Support batch or single str.
  137. Args:
  138. text (str or list of str): texts to convert.
  139. Returns:
  140. torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
  141. torch.IntTensor [n]: length of each text.
  142. """
  143. length = []
  144. result = []
  145. decode_flag = True if type(text[0])==bytes else False
  146. for item in text:
  147. if decode_flag:
  148. item = item.decode('utf-8','strict')
  149. length.append(len(item))
  150. for char in item:
  151. index = self.dict[char]
  152. result.append(index)
  153. text = result
  154. return (torch.IntTensor(text), torch.IntTensor(length))
  155. def decode(self, t, length, raw=False):
  156. """Decode encoded texts back into strs.
  157. Args:
  158. torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
  159. torch.IntTensor [n]: length of each text.
  160. Raises:
  161. AssertionError: when the texts and its length does not match.
  162. Returns:
  163. text (str or list of str): texts to convert.
  164. """
  165. if length.numel() == 1:
  166. length = length[0]
  167. assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
  168. if raw:
  169. return ''.join([self.alphabet[i - 1] for i in t])
  170. else:
  171. char_list = []
  172. for i in range(length):
  173. if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
  174. char_list.append(self.alphabet[t[i] - 1])
  175. return ''.join(char_list)
  176. else:
  177. # batch mode
  178. assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
  179. texts = []
  180. index = 0
  181. for i in range(length.numel()):
  182. l = length[i]
  183. texts.append(
  184. self.decode(
  185. t[index:index + l], torch.IntTensor([l]), raw=raw))
  186. index += l
  187. return texts
  188. def get_alphabets(txtfile ):
  189. print(txtfile)
  190. with open(txtfile,'r') as fp:
  191. lines=fp.readlines()
  192. alphas=[x.strip() for x in lines]
  193. return "".join(alphas)
  194. def get_cfg(cfg,char_file):
  195. with open(cfg, 'r') as f:
  196. #config = yaml.load(f)
  197. config = yaml.load(f, Loader=yaml.FullLoader)
  198. config = edict(config)
  199. config.DATASET.ALPHABETS = get_alphabets(char_file.strip() )
  200. config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
  201. return config
  202. def custom_mean(x):
  203. return x.prod()**(2.0/np.sqrt(len(x)))
  204. def contrast_grey(img):
  205. high = np.percentile(img, 90)
  206. low = np.percentile(img, 10)
  207. return (high-low)/np.maximum(10, high+low), high, low
  208. def adjust_contrast_grey(img, target = 0.4):
  209. contrast, high, low = contrast_grey(img)
  210. if contrast < target:
  211. img = img.astype(int)
  212. ratio = 200./np.maximum(10, high-low)
  213. img = (img - low + 25)*ratio
  214. img = np.maximum(np.full(img.shape, 0) ,np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
  215. return img
  216. class NormalizePAD(object):
  217. def __init__(self, max_size, PAD_type='right'):
  218. self.toTensor = transforms.ToTensor()
  219. self.max_size = max_size
  220. self.max_width_half = math.floor(max_size[2] / 2)
  221. self.PAD_type = PAD_type
  222. def __call__(self, img):
  223. img = self.toTensor(img)
  224. img.sub_(0.5).div_(0.5)
  225. c, h, w = img.size()
  226. Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
  227. Pad_img[:, :, :w] = img # right pad
  228. if self.max_size[2] != w: # add border Pad
  229. Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
  230. return Pad_img
  231. class AlignCollate(object):
  232. def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast = 0.):
  233. self.imgH = imgH
  234. self.imgW = imgW
  235. self.keep_ratio_with_pad = keep_ratio_with_pad
  236. self.adjust_contrast = adjust_contrast
  237. def __call__(self, batch):
  238. #print('##recongnition.py line72: type(batch[0]):',type(batch[0]),batch[0], )
  239. batch = filter(lambda x: x is not None, batch)
  240. images = batch
  241. resized_max_w = self.imgW
  242. input_channel = 1
  243. transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
  244. resized_images = []
  245. for image in images:
  246. w, h = image.size
  247. #### augmentation here - change contrast
  248. if self.adjust_contrast > 0:
  249. image = np.array(image.convert("L"))
  250. image = adjust_contrast_grey(image, target = self.adjust_contrast)
  251. image = Image.fromarray(image, 'L')
  252. ratio = w / float(h)
  253. if math.ceil(self.imgH * ratio) > self.imgW:
  254. resized_w = self.imgW
  255. else:
  256. resized_w = math.ceil(self.imgH * ratio)
  257. resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
  258. resized_images.append(transform(resized_image))
  259. image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
  260. return image_tensors