
augmentations.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Image augmentation functions
"""

import logging
import math
import random

import cv2
import numpy as np

from utils.general import colorstr, segment2box, resample_segments, check_version
from utils.metrics import bbox_ioa

class Albumentations:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self):
        self.transform = None
        try:
            import albumentations as A
            check_version(A.__version__, '1.0.3')  # version requirement

            self.transform = A.Compose([
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(p=0.0),
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0)],
                bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

            logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))
        except ImportError:  # package not installed, skip
            pass
        except Exception as e:
            logging.info(colorstr('albumentations: ') + f'{e}')

    def __call__(self, im, labels, p=1.0):
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
        return im, labels
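
# Usage sketch (illustrative, not part of the original file): labels are expected as an
# nx5 array of (class, x_center, y_center, w, h) with coordinates normalized 0-1 ('yolo' format).
# aug = Albumentations()  # silently becomes a no-op if the albumentations package is missing
# im, labels = aug(im, labels, p=1.0)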

def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
    # HSV color-space augmentation
    if hgain or sgain or vgain:
        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
        hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
        dtype = im.dtype  # uint8

        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed
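
# Usage sketch (illustrative, not part of the original file): the gains below mirror common
# YOLOv5 hyperparameter defaults; the image is modified in place and must be uint8 BGR.
# augment_hsv(im, hgain=0.015, sgain=0.7, vgain=0.4)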

def hist_equalize(im, clahe=True, bgr=False):
    # Equalize histogram on a BGR (bgr=True) or RGB (bgr=False) image 'im' with im.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image back to BGR/RGB
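
# Usage sketch (illustrative, not part of the original file): equalization runs on the
# luma (Y) channel only, so colors are preserved while contrast is stretched.
# rgb_eq = hist_equalize(im_rgb)  # RGB in, RGB out, CLAHE by default
# bgr_eq = hist_equalize(im_bgr, clahe=False, bgr=True)  # global equalizeHist on a BGR image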

def replicate(im, labels):
    # Replicate labels
    h, w = im.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset y, x
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # im4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
    return im, labels
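
# Usage sketch (illustrative, not part of the original file): the smaller half of the
# objects are copied to random positions and their new boxes appended to labels.
# im, labels = replicate(im, labels)  # labels: nx5 array of (cls, x1, y1, x2, y2) in pixels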

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)
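
# Usage sketch (illustrative, not part of the original file): with auto=True the padding is
# only taken modulo the stride, so a 720x1280 frame scales by r=0.5 to 640x360 and is padded
# to the nearest stride multiple (384x640) rather than to the full 640x640 square.
# im_padded, ratio, (dw, dh) = letterbox(im, new_shape=640, auto=True, stride=32)
# print(im_padded.shape)  # (384, 640, 3) for a 720x1280 source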

def random_perspective(im, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = im.shape[0] + border[0] * 2  # shape(h,w,c)
    width = im.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -im.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -im.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transformation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(im[:, :, ::-1])  # base
    # ax[1].imshow(im2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return im, targets
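
# Usage sketch (illustrative, not part of the original file): a negative border is how the
# mosaic loader crops an oversized mosaic back down while warping; plain augmentation uses border=(0, 0).
# im, targets = random_perspective(im, targets, degrees=0.0, translate=0.1, scale=0.5,
#                                  shear=0.0, perspective=0.0, border=(-320, -320))  # 1280 mosaic -> 640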

def copy_paste(im, labels, segments, p=0.5):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    if p and n:
        h, w, c = im.shape  # height, width, channels
        im_new = np.zeros(im.shape, np.uint8)
        for j in random.sample(range(n), k=round(p * n)):
            l, s = labels[j], segments[j]
            box = w - l[3], l[2], w - l[1], l[4]
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[l[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        result = cv2.bitwise_and(src1=im, src2=im_new)
        result = cv2.flip(result, 1)  # augment segments (flip left-right)
        i = result > 0  # pixels to replace
        # i[:, :] = result.max(2).reshape(h, w, 1)  # act over ch
        im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug

    return im, labels, segments
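
# Usage sketch (illustrative, not part of the original file): roughly p*n instances are
# mirrored left-right and pasted wherever they obscure existing labels by less than 30%.
# im, labels, segments = copy_paste(im, labels, segments, p=0.5)  # segments: list of nx2 pixel polygons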

def cutout(im, labels, p=0.5):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    if random.random() < p:
        h, w = im.shape[:2]
        scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
        for s in scales:
            mask_h = random.randint(1, int(h * s))  # create random masks
            mask_w = random.randint(1, int(w * s))

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # apply random color mask
            im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

            # return unobscured labels
            if len(labels) and s > 0.03:
                box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
                ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
                labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
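
# Usage sketch (illustrative, not part of the original file): the image is masked in place
# with up to 31 random gray patches and only the surviving labels are returned.
# labels = cutout(im, labels, p=0.5)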

def mixup(im, labels, im2, labels2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
    im = (im * r + im2 * (1 - r)).astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
    return im, labels
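
# Usage sketch (illustrative, not part of the original file): Beta(32, 32) concentrates the
# blend ratio near 0.5, so both images stay clearly visible and both label sets are kept.
# im, labels = mixup(im, labels, im2, labels2)  # im and im2 must share the same shape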

def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
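
# Worked example (illustrative, not part of the original file): a 100x100 box squashed to
# 8x90 keeps only 7.2% of its area, below the default area_thr=0.1, so it is rejected.
# before = np.array([[0., 0., 100., 100.]]).T  # shape (4, 1)
# after = np.array([[0., 0., 8., 90.]]).T
# box_candidates(before, after)  # -> array([False])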