Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

192 rindas
7.3KB

  1. import torch
  2. import numpy as np
  3. import torchvision.transforms as transforms
  4. import math
  5. from PIL import Image
  def custom_mean(x):
      """Return ``x.prod() ** (2 / sqrt(len(x)))`` for an array-like ``x``.

      ``x`` must support ``.prod()`` and ``len()`` (e.g. a NumPy array).
      """
      exponent = 2.0 / np.sqrt(len(x))
      return x.prod() ** exponent
  def contrast_grey(img):
      """Estimate the contrast of a greyscale image from its percentiles.

      Returns:
          ``(contrast, high, low)`` where ``high`` and ``low`` are the 90th and
          10th percentiles of ``img`` and
          ``contrast = (high - low) / max(10, high + low)``.
      """
      low = np.percentile(img, 10)
      high = np.percentile(img, 90)
      spread = high - low
      # The floor of 10 keeps very dark images from dividing by ~0.
      return spread / np.maximum(10, high + low), high, low
  def adjust_contrast_grey(img, target=0.4):
      """Stretch the contrast of a greyscale image when it is below ``target``.

      Args:
          img: greyscale image as a NumPy array (values presumably 0-255).
          target: minimum contrast, as measured by ``contrast_grey``, below
              which the image is rescaled.

      Returns:
          ``img`` itself when its contrast is already >= ``target``; otherwise a
          new ``uint8`` array where ``(img - low + 25) * 200 / max(10, high - low)``
          has been clipped to [0, 255].
      """
      contrast, high, low = contrast_grey(img)
      if contrast < target:
          img = img.astype(int)
          ratio = 200. / np.maximum(10, high - low)
          img = (img - low + 25) * ratio
          # np.clip does the [0, 255] saturation in one pass, without building
          # two full-size constant arrays for np.maximum / np.minimum.
          img = np.clip(img, 0, 255).astype(np.uint8)
      return img
  class NormalizePAD(object):
      """Convert an image to a tensor in [-1, 1] and right-pad it to a fixed size.

      The padded region repeats the image's last column.
      """

      def __init__(self, max_size, PAD_type='right'):
          # max_size is unpacked by __call__ as (channels, height, width).
          self.toTensor = transforms.ToTensor()
          self.max_size = max_size
          # Not read anywhere in this class; kept for callers that may use it.
          self.max_width_half = math.floor(max_size[2] / 2)
          # Stored only: __call__ always pads on the right.
          self.PAD_type = PAD_type

      def __call__(self, img):
          """Return a float32 tensor of shape ``max_size`` holding ``img``.

          The image's height and channel count must equal ``max_size[1]`` and
          ``max_size[0]``, and its width must not exceed ``max_size[2]``;
          otherwise the slice assignment raises.
          """
          tensor = self.toTensor(img)
          # In place: maps ToTensor's [0, 1] range onto [-1, 1].
          tensor.sub_(0.5).div_(0.5)
          channels, height, width = tensor.size()
          target_w = self.max_size[2]

          padded = torch.zeros(*self.max_size, dtype=torch.float32, device='cpu')
          padded[:, :, :width] = tensor  # image occupies the left part
          if width != target_w:
              # Border padding: replicate the last column over the rest.
              last_col = tensor[:, :, width - 1].unsqueeze(2)
              padded[:, :, width:] = last_col.expand(channels, height, target_w - width)
          return padded
  class AlignCollate(object):
      """Collate a batch of images into one ``(N, 1, imgH, imgW)`` tensor.

      Each image is resized to height ``imgH`` with width
      ``ceil(imgH * w / h)`` capped at ``imgW``, then normalized and
      right-padded to ``imgW`` by ``NormalizePAD``.
      """

      def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast=0.):
          self.imgH = imgH
          self.imgW = imgW
          # Stored only; __call__ always keeps the ratio and pads.
          self.keep_ratio_with_pad = keep_ratio_with_pad
          # When > 0, images are converted to greyscale and contrast-boosted
          # towards this target before resizing.
          self.adjust_contrast = adjust_contrast

      def __call__(self, batch):
          """Resize, normalize and stack the non-``None`` images of ``batch``.

          Items are used through ``.size``, ``.convert`` and ``.resize``
          (presumably PIL images). ``None`` entries are skipped.
          """
          images = (item for item in batch if item is not None)
          # NOTE(review): the pad target has a single channel, so images are
          # presumably greyscale already when adjust_contrast <= 0 — confirm.
          transform = NormalizePAD((1, self.imgH, self.imgW))
          tensors = []
          for image in images:
              w, h = image.size
              # Augmentation: optional contrast adjustment.
              if self.adjust_contrast > 0:
                  grey = np.array(image.convert("L"))
                  grey = adjust_contrast_grey(grey, target=self.adjust_contrast)
                  image = Image.fromarray(grey, 'L')
              scaled_w = math.ceil(self.imgH * (w / float(h)))
              resized_w = self.imgW if scaled_w > self.imgW else scaled_w
              resized = image.resize((resized_w, self.imgH), Image.BICUBIC)
              tensors.append(transform(resized))
          return torch.cat([t.unsqueeze(0) for t in tensors], 0)
  class CTCLabelConverter(object):
      """Convert between text labels and CTC label indices.

      Index 0 is the dummy ``[blank]`` token used by CTCLoss; the ``i``-th
      entry of ``character`` is mapped to index ``i + 1``.
      """

      def __init__(self, character, separator_list=None, dict_pathlist=None):
          """Build the lookup tables and load the optional word dictionaries.

          Args:
              character: the possible characters (usually a ``str``); their
                  order defines the indices.
              separator_list: mapping ``lang -> separator characters``.
                  ``None`` (the default) means no separators.
              dict_pathlist: mapping ``lang -> word-list file path`` (one word
                  per line, UTF-8 with optional BOM). ``None`` means no files.

          Without separators, all word lists are merged into one flat list and
          unreadable files are skipped. With separators, one list is kept per
          language and a missing or unreadable file raises.
          """
          # None defaults: a ``{}`` default is a single dict shared by every
          # call, and ``self.separator_list`` would alias it across instances.
          if separator_list is None:
              separator_list = {}
          if dict_pathlist is None:
              dict_pathlist = {}

          dict_character = list(character)
          # Later duplicates overwrite earlier ones, as with a plain loop.
          self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
          self.character = ['[blank]'] + dict_character  # index 0 = CTC blank

          self.separator_list = separator_list
          separator_char = []
          for sep in separator_list.values():
              separator_char += sep
          # Indices dropped by greedy decoding: the blank plus one index per
          # separator character.
          # NOTE(review): assumes the separators are the first
          # len(separator_char) entries of ``character`` (indices 1..n);
          # confirm against the code that builds ``character``.
          self.ignore_idx = [0] + [i + 1 for i in range(len(separator_char))]

          if len(separator_list) == 0:
              # Latin-style: one flat word list; loading is best-effort.
              dict_list = []
              for dict_path in dict_pathlist.values():
                  try:
                      with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                          dict_list += input_file.read().splitlines()
                  except (OSError, UnicodeError):
                      # Skip missing/undecodable files only, so that genuine
                      # programming errors still surface.
                      pass
          else:
              dict_list = {}
              for lang, dict_path in dict_pathlist.items():
                  with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                      dict_list[lang] = input_file.read().splitlines()
          self.dict_list = dict_list

      def encode(self, text, batch_max_length=25):
          """Convert text labels into concatenated CTC indices.

          Args:
              text: sequence of label strings, one per image.
              batch_max_length: unused; kept for interface compatibility.

          Returns:
              ``(indices, lengths)`` as ``torch.IntTensor``: ``indices`` is the
              concatenation of every label's indices (``sum(lengths)`` long)
              and ``lengths[k] == len(text[k])``.

          Raises:
              KeyError: if a label contains a character not in ``character``.
          """
          length = [len(s) for s in text]
          indices = [self.dict[char] for char in ''.join(text)]
          return (torch.IntTensor(indices), torch.IntTensor(length))

      def decode_greedy(self, text_index, length):
          """Greedy CTC decoding of back-to-back index segments.

          Within each segment, consecutive repeats are collapsed and indices in
          ``self.ignore_idx`` (blank and separators) are removed.

          Args:
              text_index: 1-D array of indices for the whole batch (presumably a
                  NumPy integer array), segments laid out one after another.
              length: number of indices in each segment.

          Returns:
              list of decoded strings, one per segment.
          """
          # Loop-invariant tables, built once per call instead of per segment.
          char_table = np.array(self.character)
          ignore = np.array(self.ignore_idx)
          texts = []
          index = 0
          for l in length:
              t = text_index[index:index + l]
              # True where a value differs from its predecessor.
              not_repeated = np.insert(~(t[1:] == t[:-1]), 0, True)
              # True where a value is neither blank nor a separator.
              not_ignored = ~np.isin(t, ignore)
              keep = not_repeated & not_ignored
              texts.append(''.join(char_table[t[keep.nonzero()]]))
              index += l
          return texts

      def decode_beamsearch(self, mat, beamWidth=5):
          """Decode each ``mat[i]`` with ``ctcBeamSearch`` (no dictionary).

          ``mat`` is iterated over its first axis, presumably
          (batch, time, classes) probabilities — confirm with the caller.
          """
          return [ctcBeamSearch(mat[i], self.character, self.ignore_idx, None,
                                beamWidth=beamWidth)
                  for i in range(mat.shape[0])]

      def decode_wordbeamsearch(self, mat, beamWidth=5):
          """Word-level beam search constrained by the loaded dictionaries.

          ``mat`` is 3-D (argmax is taken over axis 2), presumably
          (batch, time, classes).

          Without separators, each row is split on frames whose argmax is the
          space character; each run is decoded against ``self.dict_list`` and
          the words are joined with single spaces. With separators, the row is
          cut by ``word_segmentation`` into ``(lang, (start, end))`` pieces
          (end inclusive), each decoded against that language's word list
          (no list when ``lang == ''``), and the pieces are concatenated.

          Raises:
              KeyError: without separators, if ``' '`` is not in ``character``.
          """
          texts = []
          argmax = np.argmax(mat, axis=2)
          for i in range(mat.shape[0]):
              if len(self.separator_list) == 0:
                  space_idx = self.dict[' ']
                  data = np.argwhere(argmax[i] != space_idx).flatten()
                  # Split the non-space frame indices into contiguous runs.
                  runs = np.split(data, np.where(np.diff(data) != 1)[0] + 1)
                  words = []
                  for run in runs:
                      if len(run) == 0:
                          continue
                      matrix = mat[i, list(run), :]
                      words.append(ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                                 beamWidth=beamWidth, dict_list=self.dict_list))
                  string = ' '.join(words)
              else:
                  pieces = []
                  for word in word_segmentation(argmax[i]):
                      matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                      dict_list = [] if word[0] == '' else self.dict_list[word[0]]
                      pieces.append(ctcBeamSearch(matrix, self.character, self.ignore_idx, None,
                                                  beamWidth=beamWidth, dict_list=dict_list))
                  string = ''.join(pieces)
              texts.append(string)
          return texts