Bladeren bron

Merge PIL and OpenCV in `plot_one_box(use_pil=False)` (#4416)

* Merge PIL and OpenCV box plotting functions

* Add ASCII check to plot_one_box

* Cleanup

* Cleanup2
modifyDataloader
Glenn Jocher GitHub 3 jaren geleden
bovenliggende
commit
2da4e7acf7
Geen bekende sleutel gevonden voor deze handtekening in de database GPG sleutel-ID: 4AEE18F83AFDEB23
4 gewijzigde bestanden met toevoegingen van 35 en 32 verwijderingen
  1. +1
    -1
      detect.py
  2. +1
    -1
      models/common.py
  3. +6
    -1
      utils/general.py
  4. +27
    -29
      utils/plots.py

+ 1
- 1
detect.py Bestand weergeven

@@ -156,7 +156,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
if save_img or save_crop or view_img: # Add bbox to image
c = int(cls) # integer class
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=line_thickness)
im0 = plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_width=line_thickness)
if save_crop:
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)


+ 1
- 1
models/common.py Bestand weergeven

@@ -354,7 +354,7 @@ class Detections:
if crop:
save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
else: # all others
plot_one_box(box, im, label=label, color=colors(cls))
im = plot_one_box(box, im, label=label, color=colors(cls))
else:
str += '(no detections)'


+ 6
- 1
utils/general.py Bestand weergeven

@@ -110,9 +110,14 @@ def is_pip():
return 'site-packages' in Path(__file__).absolute().parts


def is_ascii(str=''):
# Is string composed of all ASCII (no UTF) characters?
return len(str.encode().decode('ascii', 'ignore')) == len(str)


def emojis(str=''):
# Return platform-dependent emoji-safe version of string
return str.encode().decode(encoding='ascii', errors='ignore') if platform.system() == 'Windows' else str
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(file):

+ 27
- 29
utils/plots.py Bestand weergeven

@@ -1,20 +1,19 @@
# Plotting utils

import math
from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont

from utils.general import xywh2xyxy, xyxy2xywh
from utils.general import is_ascii, xyxy2xywh, xywh2xyxy
from utils.metrics import fitness

# Settings
@@ -65,32 +64,31 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
return filtfilt(b, a, data) # forward-backward filter


def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
# Plots one bounding box on image 'im' using OpenCV
def plot_one_box(box, im, color=(128, 128, 128), txt_color=(255, 255, 255), label=None, line_width=3, use_pil=False):
# Plots one xyxy box on image im with label
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
if label:
tf = max(tl - 1, 1) # font thickness
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=3):
# Plots one bounding box on image 'im' using PIL
im = Image.fromarray(im)
draw = ImageDraw.Draw(im)
line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
draw.rectangle(box, width=line_thickness, outline=color) # plot
if label:
font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
txt_width, txt_height = font.getsize(label)
draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
return np.asarray(im)
lw = line_width or max(int(min(im.size) / 200), 2) # line width

if use_pil or not is_ascii(label): # use PIL
im = Image.fromarray(im)
draw = ImageDraw.Draw(im)
draw.rectangle(box, width=lw + 1, outline=color) # plot
if label:
font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
txt_width, txt_height = font.getsize(label)
draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
draw.text((box[0], box[1] - txt_height + 1), label, fill=txt_color, font=font)
return np.asarray(im)
else: # use OpenCV
c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(im, c1, c2, color, thickness=lw, lineType=cv2.LINE_AA)
if label:
tf = max(lw - 1, 1) # font thickness
txt_width, txt_height = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]
c2 = c1[0] + txt_width, c1[1] - txt_height - 3
cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
cv2.putText(im, label, (c1[0], c1[1] - 2), 0, lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
return im


def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
@@ -180,7 +178,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
cls = names[cls] if names else cls
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
mosaic = plot_one_box(box, mosaic, label=label, color=color, line_width=tl)

# Draw image filename labels
if paths:

Laden…
Annuleren
Opslaan