@@ -1,6 +1,7 @@ | |||
# This file contains modules common to various models | |||
import math | |||
import numpy as np | |||
import requests | |||
import torch | |||
@@ -240,7 +241,7 @@ class Detections: | |||
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized | |||
self.n = len(self.pred) | |||
def display(self, pprint=False, show=False, save=False): | |||
def display(self, pprint=False, show=False, save=False, render=False): | |||
colors = color_list() | |||
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)): | |||
str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} ' | |||
@@ -248,19 +249,21 @@ class Detections: | |||
for c in pred[:, -1].unique(): | |||
n = (pred[:, -1] == c).sum() # detections per class | |||
str += f'{n} {self.names[int(c)]}s, ' # add to string | |||
if show or save: | |||
if show or save or render: | |||
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np | |||
for *box, conf, cls in pred: # xyxy, confidence, class | |||
# str += '%s %.2f, ' % (names[int(cls)], conf) # label | |||
ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot | |||
if pprint: | |||
print(str) | |||
if show: | |||
img.show(f'Image {i}') # show | |||
if save: | |||
f = f'results{i}.jpg' | |||
str += f"saved to '{f}'" | |||
img.save(f) # save | |||
if show: | |||
img.show(f'Image {i}') # show | |||
if pprint: | |||
print(str) | |||
if render: | |||
self.imgs[i] = np.asarray(img) | |||
def print(self): | |||
self.display(pprint=True) # print results | |||
@@ -271,6 +274,10 @@ class Detections: | |||
def save(self): | |||
self.display(save=True) # save results | |||
def render(self): | |||
self.display(render=True) # render results | |||
return self.imgs | |||
def __len__(self): | |||
return self.n | |||
@@ -28,7 +28,7 @@ from utils.autoanchor import check_anchors | |||
from utils.datasets import create_dataloader | |||
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ | |||
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ | |||
check_requirements, print_mutation, set_logging, one_cycle | |||
check_requirements, print_mutation, set_logging, one_cycle, colorstr | |||
from utils.google_utils import attempt_download | |||
from utils.loss import compute_loss | |||
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution | |||
@@ -44,7 +44,7 @@ except ImportError: | |||
def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
logger.info(f'Hyperparameters {hyp}') | |||
logger.info(colorstr('blue', 'bold', 'Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) | |||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \ | |||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank | |||
@@ -233,9 +233,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) | |||
scheduler.last_epoch = start_epoch - 1 # do not move | |||
scaler = amp.GradScaler(enabled=cuda) | |||
logger.info('Image sizes %g train, %g test\n' | |||
'Using %g dataloader workers\nLogging results to %s\n' | |||
'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs)) | |||
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n' | |||
f'Using {dataloader.num_workers} dataloader workers\n' | |||
f'Logging results to {save_dir}\n' | |||
f'Starting training for {epochs} epochs...') | |||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ | |||
model.train() | |||
@@ -25,6 +25,7 @@ from utils.torch_utils import init_torch_seeds | |||
torch.set_printoptions(linewidth=320, precision=5, profile='long') | |||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 | |||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) | |||
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads | |||
def set_logging(rank=-1): | |||
@@ -117,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100): | |||
def colorstr(*input): | |||
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world') | |||
*prefix, str = input # color arguments, string | |||
*prefix, string = input # color arguments, string | |||
colors = {'black': '\033[30m', # basic colors | |||
'red': '\033[31m', | |||
'green': '\033[32m', | |||
@@ -136,9 +137,9 @@ def colorstr(*input): | |||
'bright_white': '\033[97m', | |||
'end': '\033[0m', # misc | |||
'bold': '\033[1m', | |||
'undelrine': '\033[4m'} | |||
'underline': '\033[4m'} | |||
return ''.join(colors[x] for x in prefix) + str + colors['end'] | |||
return ''.join(colors[x] for x in prefix) + f'{string}' + colors['end'] | |||
def labels_to_class_weights(labels, nc=80): |