選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

192 行
7.3KB

  1. import torch
  2. import numpy as np
  3. import torchvision.transforms as transforms
  4. import math
  5. from PIL import Image
  6. def custom_mean(x):
  7. return x.prod()**(2.0/np.sqrt(len(x)))
  8. def contrast_grey(img):
  9. high = np.percentile(img, 90)
  10. low = np.percentile(img, 10)
  11. return (high-low)/np.maximum(10, high+low), high, low
  12. def adjust_contrast_grey(img, target = 0.4):
  13. contrast, high, low = contrast_grey(img)
  14. if contrast < target:
  15. img = img.astype(int)
  16. ratio = 200./np.maximum(10, high-low)
  17. img = (img - low + 25)*ratio
  18. img = np.maximum(np.full(img.shape, 0) ,np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
  19. return img
  20. class NormalizePAD(object):
  21. def __init__(self, max_size, PAD_type='right'):
  22. self.toTensor = transforms.ToTensor()
  23. self.max_size = max_size
  24. self.max_width_half = math.floor(max_size[2] / 2)
  25. self.PAD_type = PAD_type
  26. def __call__(self, img):
  27. img = self.toTensor(img)
  28. img.sub_(0.5).div_(0.5)
  29. c, h, w = img.size()
  30. Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
  31. Pad_img[:, :, :w] = img # right pad
  32. if self.max_size[2] != w: # add border Pad
  33. Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
  34. return Pad_img
  35. class AlignCollate(object):
  36. def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, adjust_contrast = 0.):
  37. self.imgH = imgH
  38. self.imgW = imgW
  39. self.keep_ratio_with_pad = keep_ratio_with_pad
  40. self.adjust_contrast = adjust_contrast
  41. def __call__(self, batch):
  42. #print('##recongnition.py line72: type(batch[0]):',type(batch[0]),batch[0], )
  43. batch = filter(lambda x: x is not None, batch)
  44. images = batch
  45. resized_max_w = self.imgW
  46. input_channel = 1
  47. transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
  48. resized_images = []
  49. for image in images:
  50. w, h = image.size
  51. #### augmentation here - change contrast
  52. if self.adjust_contrast > 0:
  53. image = np.array(image.convert("L"))
  54. image = adjust_contrast_grey(image, target = self.adjust_contrast)
  55. image = Image.fromarray(image, 'L')
  56. ratio = w / float(h)
  57. if math.ceil(self.imgH * ratio) > self.imgW:
  58. resized_w = self.imgW
  59. else:
  60. resized_w = math.ceil(self.imgH * ratio)
  61. resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
  62. resized_images.append(transform(resized_image))
  63. image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
  64. return image_tensors
  65. class CTCLabelConverter(object):
  66. """ Convert between text-label and text-index """
  67. def __init__(self, character, separator_list = {}, dict_pathlist = {}):
  68. # character (str): set of the possible characters.
  69. dict_character = list(character)
  70. self.dict = {}
  71. for i, char in enumerate(dict_character):
  72. self.dict[char] = i + 1
  73. self.character = ['[blank]'] + dict_character # dummy '[blank]' token for CTCLoss (index 0)
  74. self.separator_list = separator_list
  75. separator_char = []
  76. for lang, sep in separator_list.items():
  77. separator_char += sep
  78. self.ignore_idx = [0] + [i+1 for i,item in enumerate(separator_char)]
  79. ####### latin dict
  80. if len(separator_list) == 0:
  81. dict_list = []
  82. for lang, dict_path in dict_pathlist.items():
  83. try:
  84. with open(dict_path, "r", encoding = "utf-8-sig") as input_file:
  85. word_count = input_file.read().splitlines()
  86. dict_list += word_count
  87. except:
  88. pass
  89. else:
  90. dict_list = {}
  91. for lang, dict_path in dict_pathlist.items():
  92. with open(dict_path, "r", encoding = "utf-8-sig") as input_file:
  93. word_count = input_file.read().splitlines()
  94. dict_list[lang] = word_count
  95. self.dict_list = dict_list
  96. def encode(self, text, batch_max_length=25):
  97. """convert text-label into text-index.
  98. input:
  99. text: text labels of each image. [batch_size]
  100. output:
  101. text: concatenated text index for CTCLoss.
  102. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  103. length: length of each text. [batch_size]
  104. """
  105. length = [len(s) for s in text]
  106. text = ''.join(text)
  107. text = [self.dict[char] for char in text]
  108. return (torch.IntTensor(text), torch.IntTensor(length))
  109. def decode_greedy(self, text_index, length):
  110. """ convert text-index into text-label. """
  111. texts = []
  112. index = 0
  113. for l in length:
  114. t = text_index[index:index + l]
  115. # Returns a boolean array where true is when the value is not repeated
  116. a = np.insert(~((t[1:]==t[:-1])),0,True)
  117. # Returns a boolean array where true is when the value is not in the ignore_idx list
  118. b = ~np.isin(t,np.array(self.ignore_idx))
  119. # Combine the two boolean array
  120. c = a & b
  121. # Gets the corresponding character according to the saved indexes
  122. text = ''.join(np.array(self.character)[t[c.nonzero()]])
  123. texts.append(text)
  124. index += l
  125. return texts
  126. def decode_beamsearch(self, mat, beamWidth=5):
  127. texts = []
  128. for i in range(mat.shape[0]):
  129. t = ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
  130. texts.append(t)
  131. return texts
  132. def decode_wordbeamsearch(self, mat, beamWidth=5):
  133. texts = []
  134. argmax = np.argmax(mat, axis = 2)
  135. for i in range(mat.shape[0]):
  136. string = ''
  137. # without separators - use space as separator
  138. if len(self.separator_list) == 0:
  139. space_idx = self.dict[' ']
  140. data = np.argwhere(argmax[i]!=space_idx).flatten()
  141. group = np.split(data, np.where(np.diff(data) != 1)[0]+1)
  142. group = [ list(item) for item in group if len(item)>0]
  143. for j, list_idx in enumerate(group):
  144. matrix = mat[i, list_idx,:]
  145. t = ctcBeamSearch(matrix, self.character, self.ignore_idx, None,\
  146. beamWidth=beamWidth, dict_list=self.dict_list)
  147. if j == 0: string += t
  148. else: string += ' '+t
  149. # with separators
  150. else:
  151. words = word_segmentation(argmax[i])
  152. for word in words:
  153. matrix = mat[i, word[1][0]:word[1][1]+1,:]
  154. if word[0] == '': dict_list = []
  155. else: dict_list = self.dict_list[word[0]]
  156. t = ctcBeamSearch(matrix, self.character, self.ignore_idx, None, beamWidth=beamWidth, dict_list=dict_list)
  157. string += t
  158. texts.append(string)
  159. return texts