You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

411 lines
14KB

  1. # OBSS SAHI Tool
  2. # Code written by Fatih C Akyon, 2020.
  3. import time
  4. from typing import Dict, List, Optional, Union, Tuple
  5. import numpy as np
  6. import requests
  7. from PIL import Image
  8. from numpy import ndarray
  9. def read_image_as_pil(image: Union[Image.Image, str, np.ndarray]):
  10. """
  11. Loads an image as PIL.Image.Image.
  12. Args:
  13. image : Can be image path or url (str), numpy image (np.ndarray) or PIL.Image
  14. """
  15. # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
  16. Image.MAX_IMAGE_PIXELS = None
  17. if isinstance(image, Image.Image):
  18. image_pil = image
  19. elif isinstance(image, str):
  20. # read image if str image path is provided
  21. try:
  22. image_pil = Image.open(
  23. requests.get(image, stream=True).raw if str(image).startswith("http") else image
  24. ).convert("RGB")
  25. except: # handle large/tiff image reading
  26. try:
  27. import skimage.io
  28. except ImportError:
  29. raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.")
  30. image_sk = skimage.io.imread(image).astype(np.uint8)
  31. if len(image_sk.shape) == 2: # b&w
  32. image_pil = Image.fromarray(image_sk, mode="1")
  33. elif image_sk.shape[2] == 4: # rgba
  34. image_pil = Image.fromarray(image_sk, mode="RGBA")
  35. elif image_sk.shape[2] == 3: # rgb
  36. image_pil = Image.fromarray(image_sk, mode="RGB")
  37. else:
  38. raise TypeError(f"image with shape: {image_sk.shape[3]} is not supported.")
  39. elif isinstance(image, np.ndarray):
  40. if image.shape[0] < 5: # image in CHW
  41. image = image[:, :, ::-1]
  42. image_pil = Image.fromarray(image)
  43. else:
  44. raise TypeError("read image with 'pillow' using 'Image.open()'")
  45. return image_pil
  46. def get_slice_bboxes(
  47. image_height: int,
  48. image_width: int,
  49. slice_height: int = None,
  50. slice_width: int = None,
  51. auto_slice_resolution: bool = True,
  52. overlap_height_ratio: float = 0.2,
  53. overlap_width_ratio: float = 0.2,
  54. ) -> List[List[int]]:
  55. """Slices `image_pil` in crops.
  56. Corner values of each slice will be generated using the `slice_height`,
  57. `slice_width`, `overlap_height_ratio` and `overlap_width_ratio` arguments.
  58. Args:
  59. image_height (int): Height of the original image.
  60. image_width (int): Width of the original image.
  61. slice_height (int): Height of each slice. Default 512.
  62. slice_width (int): Width of each slice. Default 512.
  63. overlap_height_ratio(float): Fractional overlap in height of each
  64. slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
  65. overlap of 20 pixels). Default 0.2.
  66. overlap_width_ratio(float): Fractional overlap in width of each
  67. slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
  68. overlap of 20 pixels). Default 0.2.
  69. auto_slice_resolution (bool): if not set slice parameters such as slice_height and slice_width,
  70. it enables automatically calculate these params from image resolution and orientation.
  71. Returns:
  72. List[List[int]]: List of 4 corner coordinates for each N slices.
  73. [
  74. [slice_0_left, slice_0_top, slice_0_right, slice_0_bottom],
  75. ...
  76. [slice_N_left, slice_N_top, slice_N_right, slice_N_bottom]
  77. ]
  78. """
  79. slice_bboxes = []
  80. y_max = y_min = 0
  81. if slice_height and slice_width:
  82. y_overlap = int(overlap_height_ratio * slice_height)
  83. x_overlap = int(overlap_width_ratio * slice_width)
  84. elif auto_slice_resolution:
  85. x_overlap, y_overlap, slice_width, slice_height = get_auto_slice_params(height=image_height, width=image_width)
  86. else:
  87. raise ValueError("Compute type is not auto and slice width and height are not provided.")
  88. while y_max < image_height:
  89. x_min = x_max = 0
  90. y_max = y_min + slice_height
  91. while x_max < image_width:
  92. x_max = x_min + slice_width
  93. if y_max > image_height or x_max > image_width:
  94. xmax = min(image_width, x_max)
  95. ymax = min(image_height, y_max)
  96. xmin = max(0, xmax - slice_width)
  97. ymin = max(0, ymax - slice_height)
  98. slice_bboxes.append([xmin, ymin, xmax, ymax])
  99. else:
  100. slice_bboxes.append([x_min, y_min, x_max, y_max])
  101. x_min = x_max - x_overlap
  102. y_min = y_max - y_overlap
  103. return slice_bboxes
  104. class SlicedImage:
  105. def __init__(self, image, starting_pixel):
  106. """
  107. image: np.array
  108. Sliced image.
  109. starting_pixel: list of list of int
  110. Starting pixel coordinates of the sliced image.
  111. """
  112. self.image = image
  113. self.starting_pixel = starting_pixel
  114. class SliceImageResult:
  115. def __init__(self, original_image_size=None):
  116. """
  117. sliced_image_list: list of SlicedImage
  118. image_dir: str
  119. Directory of the sliced image exports.
  120. original_image_size: list of int
  121. Size of the unsliced original image in [height, width]
  122. """
  123. self._sliced_image_list: List[SlicedImage] = []
  124. self.original_image_height = original_image_size[0]
  125. self.original_image_width = original_image_size[1]
  126. def add_sliced_image(self, sliced_image: SlicedImage):
  127. if not isinstance(sliced_image, SlicedImage):
  128. raise TypeError("sliced_image must be a SlicedImage instance")
  129. self._sliced_image_list.append(sliced_image)
  130. @property
  131. def sliced_image_list(self):
  132. return self._sliced_image_list
  133. @property
  134. def images(self):
  135. """Returns sliced images.
  136. Returns:
  137. images: a list of np.array
  138. """
  139. images = []
  140. for sliced_image in self._sliced_image_list:
  141. images.append(sliced_image.image)
  142. return images
  143. @property
  144. def starting_pixels(self) -> List[int]:
  145. """Returns a list of starting pixels for each slice.
  146. Returns:
  147. starting_pixels: a list of starting pixel coords [x,y]
  148. """
  149. starting_pixels = []
  150. for sliced_image in self._sliced_image_list:
  151. starting_pixels.append(sliced_image.starting_pixel)
  152. return starting_pixels
  153. def __len__(self):
  154. return len(self._sliced_image_list)
  155. def slice_image(
  156. image: Union[str, Image.Image],
  157. slice_height: int = None,
  158. slice_width: int = None,
  159. overlap_height_ratio: float = None,
  160. overlap_width_ratio: float = None,
  161. auto_slice_resolution: bool = True,
  162. ) -> Tuple[ndarray, ndarray]:
  163. """Slice a large image into smaller windows. If output_file_name is given export
  164. sliced images.
  165. Args:
  166. auto_slice_resolution:
  167. image (str or PIL.Image): File path of image or Pillow Image to be sliced.
  168. coco_annotation_list (CocoAnnotation): List of CocoAnnotation objects.
  169. output_file_name (str, optional): Root name of output files (coordinates will
  170. be appended to this)
  171. output_dir (str, optional): Output directory
  172. slice_height (int): Height of each slice. Default 512.
  173. slice_width (int): Width of each slice. Default 512.
  174. overlap_height_ratio (float): Fractional overlap in height of each
  175. slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
  176. overlap of 20 pixels). Default 0.2.
  177. overlap_width_ratio (float): Fractional overlap in width of each
  178. slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
  179. overlap of 20 pixels). Default 0.2.
  180. min_area_ratio (float): If the cropped annotation area to original annotation
  181. ratio is smaller than this value, the annotation is filtered out. Default 0.1.
  182. out_ext (str, optional): Extension of saved images. Default is the
  183. original suffix.
  184. verbose (bool, optional): Switch to print relevant values to screen.
  185. Default 'False'.
  186. Returns:
  187. sliced_image_result: SliceImageResult:
  188. sliced_image_list: list of SlicedImage
  189. image_dir: str
  190. Directory of the sliced image exports.
  191. original_image_size: list of int
  192. Size of the unsliced original image in [height, width]
  193. num_total_invalid_segmentation: int
  194. Number of invalid segmentation annotations.
  195. """
  196. # read image
  197. image_pil = read_image_as_pil(image)
  198. image_width, image_height = image_pil.size
  199. if not (image_width != 0 and image_height != 0):
  200. raise RuntimeError(f"invalid image size: {image_pil.size} for 'slice_image'.")
  201. slice_bboxes = get_slice_bboxes(
  202. image_height=image_height,
  203. image_width=image_width,
  204. auto_slice_resolution=auto_slice_resolution,
  205. slice_height=slice_height,
  206. slice_width=slice_width,
  207. overlap_height_ratio=overlap_height_ratio,
  208. overlap_width_ratio=overlap_width_ratio,
  209. )
  210. t0 = time.time()
  211. n_ims = 0
  212. # init images and annotations lists
  213. sliced_image_result = SliceImageResult(original_image_size=[image_height, image_width])
  214. image_pil_arr = np.asarray(image_pil)
  215. # iterate over slices
  216. for slice_bbox in slice_bboxes:
  217. n_ims += 1
  218. # extract image
  219. tlx = slice_bbox[0]
  220. tly = slice_bbox[1]
  221. brx = slice_bbox[2]
  222. bry = slice_bbox[3]
  223. image_pil_slice = image_pil_arr[tly:bry, tlx:brx]
  224. # create sliced image and append to sliced_image_result
  225. sliced_image = SlicedImage(
  226. image=image_pil_slice, starting_pixel=[slice_bbox[0], slice_bbox[1]]
  227. )
  228. sliced_image_result.add_sliced_image(sliced_image)
  229. image_numpy = np.array(sliced_image_result.images)
  230. shift_amount = np.array(sliced_image_result.starting_pixels)
  231. return image_numpy, shift_amount
  232. def calc_ratio_and_slice(orientation, slide=1, ratio=0.1):
  233. """
  234. According to image resolution calculation overlap params
  235. Args:
  236. orientation: image capture angle
  237. slide: sliding window
  238. ratio: buffer value
  239. Returns:
  240. overlap params
  241. """
  242. if orientation == "vertical":
  243. slice_row, slice_col, overlap_height_ratio, overlap_width_ratio = slide, slide * 2, ratio, ratio
  244. elif orientation == "horizontal":
  245. slice_row, slice_col, overlap_height_ratio, overlap_width_ratio = slide * 2, slide, ratio, ratio
  246. elif orientation == "square":
  247. slice_row, slice_col, overlap_height_ratio, overlap_width_ratio = slide, slide, ratio, ratio
  248. return slice_row, slice_col, overlap_height_ratio, overlap_width_ratio # noqa
  249. def calc_resolution_factor(resolution: int) -> int:
  250. """
  251. According to image resolution calculate power(2,n) and return the closest smaller `n`.
  252. Args:
  253. resolution: the width and height of the image multiplied. such as 1024x720 = 737280
  254. Returns:
  255. """
  256. expo = 0
  257. while np.power(2, expo) < resolution:
  258. expo += 1
  259. return expo - 1
  260. def calc_aspect_ratio_orientation(width: int, height: int) -> str:
  261. """
  262. Args:
  263. width:
  264. height:
  265. Returns:
  266. image capture orientation
  267. """
  268. if width < height:
  269. return "vertical"
  270. elif width > height:
  271. return "horizontal"
  272. else:
  273. return "square"
  274. def calc_slice_and_overlap_params(resolution: str, height: int, width: int, orientation: str) -> List:
  275. """
  276. This function calculate according to image resolution slice and overlap params.
  277. Args:
  278. resolution: str
  279. height: int
  280. width: int
  281. orientation: str
  282. Returns:
  283. x_overlap, y_overlap, slice_width, slice_height
  284. """
  285. if resolution == "medium":
  286. split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
  287. orientation, slide=1, ratio=0.8
  288. )
  289. elif resolution == "high":
  290. split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
  291. orientation, slide=2, ratio=0.4
  292. )
  293. elif resolution == "ultra-high":
  294. split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
  295. orientation, slide=4, ratio=0.4
  296. )
  297. else: # low condition
  298. split_col = 1
  299. split_row = 1
  300. overlap_width_ratio = 1
  301. overlap_height_ratio = 1
  302. slice_height = height // split_col
  303. slice_width = width // split_row
  304. x_overlap = int(slice_width * overlap_width_ratio)
  305. y_overlap = int(slice_height * overlap_height_ratio)
  306. return x_overlap, y_overlap, slice_width, slice_height # noqa
  307. def get_resolution_selector(res: str, height: int, width: int):
  308. """
  309. Args:
  310. res: resolution of image such as low, medium
  311. height:
  312. width:
  313. Returns:
  314. trigger slicing params function and return overlap params
  315. """
  316. orientation = calc_aspect_ratio_orientation(width=width, height=height)
  317. x_overlap, y_overlap, slice_width, slice_height = calc_slice_and_overlap_params(
  318. resolution=res, height=height, width=width, orientation=orientation
  319. )
  320. return x_overlap, y_overlap, slice_width, slice_height
  321. def get_auto_slice_params(height: int, width: int):
  322. """
  323. According to Image HxW calculate overlap sliding window and buffer params
  324. factor is the power value of 2 closest to the image resolution.
  325. factor <= 18: low resolution image such as 300x300, 640x640
  326. 18 < factor <= 21: medium resolution image such as 1024x1024, 1336x960
  327. 21 < factor <= 24: high resolution image such as 2048x2048, 2048x4096, 4096x4096
  328. factor > 24: ultra-high resolution image such as 6380x6380, 4096x8192
  329. Args:
  330. height:
  331. width:
  332. Returns:
  333. slicing overlap params x_overlap, y_overlap, slice_width, slice_height
  334. """
  335. resolution = height * width
  336. factor = calc_resolution_factor(resolution)
  337. if factor <= 18:
  338. return get_resolution_selector("low", height=height, width=width)
  339. elif 18 <= factor < 21:
  340. return get_resolution_selector("medium", height=height, width=width)
  341. elif 21 <= factor < 24:
  342. return get_resolution_selector("high", height=height, width=width)
  343. else:
  344. return get_resolution_selector("ultra-high", height=height, width=width)