
add

pinjie.py
sliceing.py
crop_pinjie
Administrator 1 year ago
commit 9ada63cc3d
2 changed files with 432 additions and 0 deletions
  1. utils/pinjie.py  +22 -0
  2. utils/sliceing.py  +410 -0

utils/pinjie.py  +22 -0

@@ -0,0 +1,22 @@
import logging
import os

import torch

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)


def get_pinjie(img, shift):
    # img: per-slice detections of shape (num_slices, nbox, 7)
    # shift: np.ndarray of shape (num_slices, 2) with each slice's [x, y] offset
    nbox = img.shape[1]
    shift = torch.from_numpy(shift).to(img.device)
    shift = shift.unsqueeze(1).repeat(1, nbox, 1)
    # add each slice's offset to the first two values (x, y) of every box
    img[..., :2] += shift

    # flatten all slices into a single (1, num_slices * nbox, 7) tensor
    img_out = img.view(1, -1, 7)

    return img_out
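
A minimal usage sketch for get_pinjie, assuming (not stated in the commit) that detections arrive as a tensor of shape (num_slices, nbox, 7) whose first two values per box are slice-local x, y coordinates, and that shift holds each slice's top-left [x, y] offset, as produced by slice_image in utils/sliceing.py:

import numpy as np
import torch

from utils.pinjie import get_pinjie

# hypothetical detections: 3 slices, 5 boxes each, 7 values per box
dets = torch.rand(3, 5, 7)
# top-left [x, y] offset of each slice within the full image
shift = np.array([[0, 0], [410, 0], [820, 0]])

merged = get_pinjie(dets, shift)  # shape (1, 15, 7), x/y now in full-image coords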

utils/sliceing.py  +410 -0

@@ -0,0 +1,410 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

import time
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import requests
from PIL import Image
from numpy import ndarray


def read_image_as_pil(image: Union[Image.Image, str, np.ndarray]):
    """
    Loads an image as PIL.Image.Image.

    Args:
        image: Can be an image path or URL (str), a numpy image (np.ndarray) or a PIL.Image.
    """
    # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
    Image.MAX_IMAGE_PIXELS = None

    if isinstance(image, Image.Image):
        image_pil = image
    elif isinstance(image, str):
        # read image if a str image path or URL is provided
        try:
            image_pil = Image.open(
                requests.get(image, stream=True).raw if str(image).startswith("http") else image
            ).convert("RGB")
        except Exception:  # handle large/tiff image reading
            try:
                import skimage.io
            except ImportError:
                raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.")
            image_sk = skimage.io.imread(image).astype(np.uint8)
            if len(image_sk.shape) == 2:  # b&w
                image_pil = Image.fromarray(image_sk, mode="1")
            elif image_sk.shape[2] == 4:  # rgba
                image_pil = Image.fromarray(image_sk, mode="RGBA")
            elif image_sk.shape[2] == 3:  # rgb
                image_pil = Image.fromarray(image_sk, mode="RGB")
            else:
                raise TypeError(f"image with shape: {image_sk.shape} is not supported.")
    elif isinstance(image, np.ndarray):
        if image.shape[0] < 5:  # image in CHW
            image = image[:, :, ::-1]
        image_pil = Image.fromarray(image)
    else:
        raise TypeError("image must be a PIL.Image, a file path/URL str, or a np.ndarray")
    return image_pil

def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = None,
    slice_width: int = None,
    auto_slice_resolution: bool = True,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """Slices an image into overlapping crops.

    Corner values of each slice are generated using the `slice_height`,
    `slice_width`, `overlap_height_ratio` and `overlap_width_ratio` arguments.

    Args:
        image_height (int): Height of the original image.
        image_width (int): Width of the original image.
        slice_height (int): Height of each slice. If None, computed automatically.
        slice_width (int): Width of each slice. If None, computed automatically.
        overlap_height_ratio (float): Fractional overlap in height of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        overlap_width_ratio (float): Fractional overlap in width of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels). Default 0.2.
        auto_slice_resolution (bool): If `slice_height` and `slice_width` are not
            given, automatically calculate them from the image resolution and
            orientation.

    Returns:
        List[List[int]]: List of 4 corner coordinates for each of the N slices.
            [
                [slice_0_left, slice_0_top, slice_0_right, slice_0_bottom],
                ...
                [slice_N_left, slice_N_top, slice_N_right, slice_N_bottom]
            ]
    """
    slice_bboxes = []
    y_max = y_min = 0

    if slice_height and slice_width:
        y_overlap = int(overlap_height_ratio * slice_height)
        x_overlap = int(overlap_width_ratio * slice_width)
    elif auto_slice_resolution:
        x_overlap, y_overlap, slice_width, slice_height = get_auto_slice_params(height=image_height, width=image_width)
    else:
        raise ValueError("Compute type is not auto and slice width and height are not provided.")

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - slice_width)
                ymin = max(0, ymax - slice_height)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes
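

# Example (illustrative): a 1920x1080 frame with 512x512 slices and 0.2 overlap
# produces a 5x3 grid of 15 windows; border windows are shifted back inside the
# image so every slice stays exactly 512x512:
#
#   get_slice_bboxes(1080, 1920, slice_height=512, slice_width=512)
#   -> [[0, 0, 512, 512], [410, 0, 922, 512], ..., [1408, 568, 1920, 1080]]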


class SlicedImage:
    def __init__(self, image, starting_pixel):
        """
        image: np.ndarray
            Sliced image.
        starting_pixel: list of int
            Starting pixel coordinates [x, y] of the sliced image.
        """
        self.image = image
        self.starting_pixel = starting_pixel


class SliceImageResult:
    def __init__(self, original_image_size=None):
        """
        original_image_size: list of int
            Size of the unsliced original image in [height, width].
        """
        self._sliced_image_list: List[SlicedImage] = []
        self.original_image_height = original_image_size[0]
        self.original_image_width = original_image_size[1]

    def add_sliced_image(self, sliced_image: SlicedImage):
        if not isinstance(sliced_image, SlicedImage):
            raise TypeError("sliced_image must be a SlicedImage instance")

        self._sliced_image_list.append(sliced_image)

    @property
    def sliced_image_list(self):
        return self._sliced_image_list

    @property
    def images(self):
        """Returns sliced images.

        Returns:
            images: a list of np.ndarray
        """
        images = []
        for sliced_image in self._sliced_image_list:
            images.append(sliced_image.image)
        return images

    @property
    def starting_pixels(self) -> List[List[int]]:
        """Returns a list of starting pixels for each slice.

        Returns:
            starting_pixels: a list of starting pixel coords [x, y]
        """
        starting_pixels = []
        for sliced_image in self._sliced_image_list:
            starting_pixels.append(sliced_image.starting_pixel)
        return starting_pixels

    def __len__(self):
        return len(self._sliced_image_list)


def slice_image(
    image: Union[str, Image.Image],
    slice_height: int = None,
    slice_width: int = None,
    overlap_height_ratio: float = None,
    overlap_width_ratio: float = None,
    auto_slice_resolution: bool = True,
) -> Tuple[ndarray, ndarray]:
    """Slice a large image into smaller overlapping windows.

    Args:
        image (str or PIL.Image): File path of the image, or a Pillow Image, to be sliced.
        slice_height (int): Height of each slice. If None, computed automatically.
        slice_width (int): Width of each slice. If None, computed automatically.
        overlap_height_ratio (float): Fractional overlap in height of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels).
        overlap_width_ratio (float): Fractional overlap in width of each
            slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
            overlap of 20 pixels).
        auto_slice_resolution (bool): If slice sizes are not given, compute them
            from the image resolution and orientation.

    Returns:
        image_numpy (np.ndarray): Array of sliced images.
        shift_amount (np.ndarray): Per-slice [x, y] top-left starting pixels, used
            to shift slice-level detections back to full-image coordinates.
    """

    # read image
    image_pil = read_image_as_pil(image)

    image_width, image_height = image_pil.size
    if not (image_width != 0 and image_height != 0):
        raise RuntimeError(f"invalid image size: {image_pil.size} for 'slice_image'.")
    slice_bboxes = get_slice_bboxes(
        image_height=image_height,
        image_width=image_width,
        auto_slice_resolution=auto_slice_resolution,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )

    t0 = time.time()
    n_ims = 0

    # init images and annotations lists
    sliced_image_result = SliceImageResult(original_image_size=[image_height, image_width])

    image_pil_arr = np.asarray(image_pil)
    # iterate over slices
    for slice_bbox in slice_bboxes:
        n_ims += 1

        # extract image
        tlx = slice_bbox[0]
        tly = slice_bbox[1]
        brx = slice_bbox[2]
        bry = slice_bbox[3]
        image_pil_slice = image_pil_arr[tly:bry, tlx:brx]

        # create sliced image and append to sliced_image_result
        sliced_image = SlicedImage(
            image=image_pil_slice, starting_pixel=[slice_bbox[0], slice_bbox[1]]
        )
        sliced_image_result.add_sliced_image(sliced_image)

    image_numpy = np.array(sliced_image_result.images)
    shift_amount = np.array(sliced_image_result.starting_pixels)

    return image_numpy, shift_amount
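

# Example (illustrative; "frame.jpg" is a placeholder path):
#
#   slices, shifts = slice_image("frame.jpg", slice_height=512, slice_width=512,
#                                overlap_height_ratio=0.2, overlap_width_ratio=0.2)
#   # slices: (num_slices, 512, 512, 3) array when all slices share one size
#   # shifts: (num_slices, 2) array of per-slice top-left [x, y] offsets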


def calc_ratio_and_slice(orientation, slide=1, ratio=0.1):
    """
    Calculate the slice grid and overlap params according to image orientation.

    Args:
        orientation: image capture orientation ("vertical", "horizontal" or "square")
        slide: sliding window factor
        ratio: buffer (overlap) value

    Returns:
        slice_row, slice_col, overlap_height_ratio, overlap_width_ratio
    """
    if orientation == "vertical":
        slice_row, slice_col, overlap_height_ratio, overlap_width_ratio = slide, slide * 2, ratio, ratio
    elif orientation == "horizontal":
        slice_row, slice_col, overlap_height_ratio, overlap_width_ratio = slide * 2, slide, ratio, ratio
    elif orientation == "square":
        slice_row, slice_col, overlap_height_ratio, overlap_width_ratio = slide, slide, ratio, ratio

    return slice_row, slice_col, overlap_height_ratio, overlap_width_ratio  # noqa


def calc_resolution_factor(resolution: int) -> int:
    """
    Calculate the power-of-two exponent closest to (and below) the image resolution.

    Args:
        resolution: width times height of the image, e.g. 1024 x 720 = 737280

    Returns:
        The largest `n` such that 2**n < resolution.
    """
    expo = 0
    while np.power(2, expo) < resolution:
        expo += 1

    return expo - 1
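

# Example (illustrative): a 1920x1080 image has resolution 1920 * 1080 = 2073600;
# 2**20 = 1048576 < 2073600 <= 2**21 = 2097152, so calc_resolution_factor(2073600) == 20.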


def calc_aspect_ratio_orientation(width: int, height: int) -> str:
    """
    Determine the image capture orientation from its aspect ratio.

    Args:
        width: image width
        height: image height

    Returns:
        "vertical", "horizontal" or "square"
    """

    if width < height:
        return "vertical"
    elif width > height:
        return "horizontal"
    else:
        return "square"


def calc_slice_and_overlap_params(resolution: str, height: int, width: int, orientation: str) -> List:
    """
    Calculate the slice and overlap params according to the image resolution class.

    Args:
        resolution: resolution class ("low", "medium", "high" or "ultra-high")
        height: image height
        width: image width
        orientation: image orientation ("vertical", "horizontal" or "square")

    Returns:
        x_overlap, y_overlap, slice_width, slice_height
    """

    if resolution == "medium":
        split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
            orientation, slide=1, ratio=0.8
        )

    elif resolution == "high":
        split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
            orientation, slide=2, ratio=0.4
        )

    elif resolution == "ultra-high":
        split_row, split_col, overlap_height_ratio, overlap_width_ratio = calc_ratio_and_slice(
            orientation, slide=4, ratio=0.4
        )
    else:  # low condition
        split_col = 1
        split_row = 1
        overlap_width_ratio = 1
        overlap_height_ratio = 1

    slice_height = height // split_col
    slice_width = width // split_row

    x_overlap = int(slice_width * overlap_width_ratio)
    y_overlap = int(slice_height * overlap_height_ratio)

    return x_overlap, y_overlap, slice_width, slice_height  # noqa


def get_resolution_selector(res: str, height: int, width: int):
    """
    Select the slicing params for the given resolution class.

    Args:
        res: resolution class of the image, such as "low" or "medium"
        height: image height
        width: image width

    Returns:
        x_overlap, y_overlap, slice_width, slice_height
    """
    orientation = calc_aspect_ratio_orientation(width=width, height=height)
    x_overlap, y_overlap, slice_width, slice_height = calc_slice_and_overlap_params(
        resolution=res, height=height, width=width, orientation=orientation
    )

    return x_overlap, y_overlap, slice_width, slice_height


def get_auto_slice_params(height: int, width: int):
    """
    Calculate the sliding window and overlap params from the image height and width.

    `factor` is the power-of-two exponent closest to the image resolution:
        factor <= 18: low resolution image, such as 300x300 or 640x640
        18 < factor < 21: medium resolution image, such as 1024x1024 or 1336x960
        21 <= factor < 24: high resolution image, such as 2048x2048, 2048x4096 or 4096x4096
        factor >= 24: ultra-high resolution image, such as 6380x6380 or 4096x8192

    Args:
        height: image height
        width: image width

    Returns:
        slicing overlap params: x_overlap, y_overlap, slice_width, slice_height
    """
    resolution = height * width
    factor = calc_resolution_factor(resolution)
    if factor <= 18:
        return get_resolution_selector("low", height=height, width=width)
    elif 18 <= factor < 21:
        return get_resolution_selector("medium", height=height, width=width)
    elif 21 <= factor < 24:
        return get_resolution_selector("high", height=height, width=width)
    else:
        return get_resolution_selector("ultra-high", height=height, width=width)
