Browse Source

New Colors() class (#2963)

modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
57812df68c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 17 deletions
  1. +3
    -6
      detect.py
  2. +2
    -3
      models/common.py
  3. +16
    -8
      utils/plots.py

+ 3
- 6
detect.py View File

@@ -11,7 +11,7 @@ from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import plot_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized


@@ -34,6 +34,7 @@ def detect(opt):
model = attempt_load(weights, map_location=device) # load FP32 model
stride = int(model.stride.max()) # model stride
imgsz = check_img_size(imgsz, s=stride) # check img_size
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16

@@ -52,10 +53,6 @@ def detect(opt):
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride)

# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

# Run inference
if device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
@@ -112,7 +109,7 @@ def detect(opt):
c = int(cls) # integer class
label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')

plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=opt.line_thickness)
plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness)
if opt.save_crop:
save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)


+ 2
- 3
models/common.py View File

@@ -14,7 +14,7 @@ from torch.cuda import amp

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
from utils.plots import color_list, plot_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import time_synchronized


@@ -312,7 +312,6 @@ class Detections:
self.s = shape # inference BCHW shape

def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
colors = color_list()
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
if pred is not None:
@@ -325,7 +324,7 @@ class Detections:
if crop:
save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
else: # all others
plot_one_box(box, im, label=label, color=colors[int(cls) % 10])
plot_one_box(box, im, label=label, color=colors(cls))

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if pprint:

+ 16
- 8
utils/plots.py View File

@@ -26,12 +26,22 @@ matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg') # for writing to files only


def color_list():
# Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
def hex2rgb(h):
class Colors:
# Ultralytics color palette https://ultralytics.com/
def __init__(self):
self.palette = [self.hex2rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()]
self.n = len(self.palette)

def __call__(self, i, bgr=False):
c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c

@staticmethod
def hex2rgb(h): # rgb order (PIL)
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949)

colors = Colors() # create instance for 'from utils.plots import colors'


def hist2d(x, y, n=100):
@@ -137,7 +147,6 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
h = math.ceil(scale_factor * h)
w = math.ceil(scale_factor * w)

colors = color_list() # list of colors
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, img in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
@@ -168,7 +177,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
boxes[[1, 3]] += block_y
for j, box in enumerate(boxes.T):
cls = int(classes[j])
color = colors[cls % len(colors)]
color = colors(cls)
cls = names[cls] if names else cls
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
@@ -276,7 +285,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
colors = color_list()
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

# seaborn correlogram
@@ -302,7 +310,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
for cls, *box in labels[:1000]:
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
ax[1].axis('off')


Loading…
Cancel
Save