|
|
@@ -67,51 +67,59 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): |
|
|
|
return filtfilt(b, a, data) # forward-backward filter |
|
|
|
|
|
|
|
|
|
|
|
def plot_one_box(box, im, color=(128, 128, 128), txt_color=(255, 255, 255), label=None, line_width=3, use_pil=False):
    """Draw one xyxy box, with an optional label, on image ``im``.

    Args:
        box: Indexable of four values read as ``box[0]..box[3]`` (x1, y1, x2, y2).
        im: Contiguous image array; drawn on in place by the OpenCV path.
        color: Box outline and label background colour.
        txt_color: Label text colour.
        label: Optional text drawn above the top-left corner of the box.
        line_width: Box line width; a falsy value selects an automatic width
            derived from the smaller image side.
        use_pil: Force the PIL drawing path. PIL is also used automatically
            when ``label`` is not ASCII.

    Returns:
        The annotated image: a new array from the PIL path, or ``im`` itself
        from the OpenCV path.
    """
    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
    # ndarray.size is a scalar element count, so the spatial dims must come from .shape.
    lw = line_width or max(int(min(im.shape[:2]) / 200), 2)  # line width

    if use_pil or (label is not None and not is_ascii(label)):  # use PIL
        im = Image.fromarray(im)
        draw = ImageDraw.Draw(im)
        draw.rectangle(box, width=lw + 1, outline=color)  # plot
        if label:
            font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
            # FreeTypeFont.getsize() was removed in Pillow 10; the right/bottom
            # edges of getbbox() are the same (width, height) pair.
            if hasattr(font, 'getbbox'):
                txt_width, txt_height = font.getbbox(label)[2:4]
            else:
                txt_width, txt_height = font.getsize(label)
            draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
            draw.text((box[0], box[1] - txt_height + 1), label, fill=txt_color, font=font)
        return np.asarray(im)

    # use OpenCV
    c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(im, c1, c2, color, thickness=lw, lineType=cv2.LINE_AA)
    if label:
        tf = max(lw - 1, 1)  # font thickness
        txt_width, txt_height = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]
        c2 = c1[0] + txt_width, c1[1] - txt_height - 3
        cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(im, label, (c1[0], c1[1] - 2), 0, lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
    return im
|
|
|
|
|
|
|
|
|
|
|
def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
    """Plot exp(x) against two powers of a doubled sigmoid and save 'comparison.png'.

    Compares the two methods for width-height anchor multiplication.
    https://github.com/ultralytics/yolov3/issues/168
    """
    inputs = np.arange(-4.0, 4.0, .1)
    exp_curve = np.exp(inputs)
    doubled_sigmoid = torch.sigmoid(torch.from_numpy(inputs)).numpy() * 2

    figure = plt.figure(figsize=(6, 3), tight_layout=True)
    curves = (
        (exp_curve, 'YOLOv3'),
        (doubled_sigmoid ** 2, 'YOLOv5 ^2'),
        (doubled_sigmoid ** 1.6, 'YOLOv5 ^1.6'),
    )
    for values, legend_name in curves:
        plt.plot(inputs, values, '.-', label=legend_name)
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.grid()
    plt.legend()
    figure.savefig('comparison.png', dpi=200)
|
|
|
class Annotator:
    """Image annotator that draws boxes, labels, rectangles and text with PIL or OpenCV.

    The backend is fixed at construction time by ``pil``. ``rectangle`` and
    ``text`` are PIL-only and require ``pil=True``.
    """

    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=True):
        """Prepare ``im`` for annotation.

        Args:
            im: Image as a contiguous ndarray, or (PIL mode) a ``PIL.Image.Image``.
            line_width: Box line width; falsy selects a width scaled to the image.
            font_size: PIL font size; falsy selects a size scaled to the image.
            font: TrueType font file name; fetched from the YOLOv5 release
                assets if it cannot be loaded locally.
            pil: Draw with PIL when true, otherwise with OpenCV on ``im`` in place.
        """
        # PIL images have no ``.data`` buffer, so only ndarrays can be checked.
        if isinstance(im, np.ndarray):
            assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
        self.pil = pil
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            s = sum(self.im.size) / 2  # mean shape
            f = font_size or max(round(s * 0.035), 12)
            try:
                self.font = ImageFont.truetype(font, size=f)
            except Exception:  # font not loadable locally: download TTF, then retry
                url = "https://github.com/ultralytics/yolov5/releases/download/v1.0/" + font
                torch.hub.download_url_to_file(url, font)
                self.font = ImageFont.truetype(font, size=f)
            self.fh = self._text_size('a')[1] - 3  # font height
        else:  # use cv2
            self.im = im
            s = sum(im.shape) / 2  # mean shape
        # Needed by both backends: box_label() reads it on the PIL path too.
        self.lw = line_width or max(round(s * 0.003), 2)  # line width

    def _text_size(self, text):
        """Return ``(width, height)`` of ``text`` rendered in ``self.font``."""
        # FreeTypeFont.getsize() was removed in Pillow 10; the right/bottom
        # edges of getbbox() are the same pair.
        if hasattr(self.font, 'getbbox'):
            return tuple(self.font.getbbox(text)[2:4])
        return self.font.getsize(text)

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
        """Add one xyxy box to the image, with an optional label above its top-left corner.

        The backend follows ``self.pil``. A cv2-mode annotator has no PIL draw
        context or font, so it always draws with OpenCV, even for non-ASCII
        labels (cv2.putText cannot render non-ASCII glyphs correctly).
        """
        if self.pil:
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w = self._text_size(label)[0]  # text width
                self.draw.rectangle([box[0], box[1] - self.fh, box[0] + w + 1, box[1] + 1], fill=color)
                self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')
        else:  # cv2
            c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, c1, c2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
                c2 = c1[0] + w, c1[1] - h - 3
                cv2.rectangle(self.im, c1, c2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(self.im, label, (c1[0], c1[1] - 2), 0, self.lw / 3, txt_color, thickness=tf,
                            lineType=cv2.LINE_AA)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Add a rectangle to the image (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255)):
        """Add text to the image (PIL-only), shifted up by the text height minus one pixel."""
        h = self._text_size(text)[1]  # text height
        self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)

    def result(self):
        """Return the annotated image as an array."""
        return np.asarray(self.im)
|
|
|
|
|
|
|
|
|
|
|
def output_to_target(output): |
|
|
@@ -123,82 +131,65 @@ def output_to_target(output): |
|
|
|
return np.array(targets) |
|
|
|
|
|
|
|
|
|
|
|
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
    """Plot a batch of images as a square mosaic with boxes, labels and file names.

    Args:
        images: Batch unpacked as (batch, channels, height, width); a torch
            tensor is converted to a float NumPy array. Values with a maximum
            of at most 1 in the first image are scaled by 255. The mosaic has
            three channels, so three-channel images are assumed.
        targets: 2-D array or tensor, one row per box. Column 0 is the image
            index within the batch, column 1 the class index, columns 2:6 the
            box passed to ``xywh2xyxy``; an optional column 6 is the confidence
            (its absence marks rows as ground-truth labels).
        paths: Optional per-image paths; the file name (first 40 characters)
            is drawn on each tile.
        fname: Output file for the mosaic; nothing is saved when falsy.
        names: Optional class-index to class-name lookup.
        max_size: Maximum side length in pixels of the whole mosaic.
        max_subplots: Maximum number of images placed in the mosaic.

    Returns:
        The annotated mosaic as an array.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()
    if np.max(images[0]) <= 1:
        # De-normalise (optional). Not done in place: the array may be the
        # caller's own data or share memory with the caller's tensor.
        images = images * 255.0
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images[:bs]):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y:y + h, x:x + w, :] = im.transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(side * ns) for side in (w, h)))

    # Annotate
    fs = int(h * ns * 0.02)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs)
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(targets) > 0:
            ti = targets[targets[:, 0] == i]  # image targets
            boxes = xywh2xyxy(ti[:, 2:6]).T
            classes = ti[:, 1].astype('int')
            labels = ti.shape[1] == 6  # labels if no conf column
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                cls = classes[j]
                color = colors(cls)
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                    annotator.box_label(box, label, color=color)

    if fname:
        annotator.im.save(fname)  # save
    return annotator.result()
|
|
|
|
|
|
|
|
|
|
|
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): |