Преглед на файлове

kmean_anchors() update

5.0
Glenn Jocher преди 4 години
родител
ревизия
8fa3724072
променени са 1 файла, в които са добавени 21 реда и са изтрити 8 реда
  1. +21
    -8
      utils/utils.py

+ 21
- 8
utils/utils.py Целия файл

@@ -16,6 +16,7 @@ import numpy as np
import torch
import torch.nn as nn
import torchvision
import yaml
from scipy.signal import butter, filtfilt
from tqdm import tqdm

@@ -686,12 +687,23 @@ def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images


def kmean_anchors(path='./data/coco128.txt', n=9, img_size=(640, 640), thr=0.20, gen=1000):
# Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
# n: number of anchors
# img_size: (min, max) image size used for multi-scale training (can be same values)
# thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
# gen: generations to evolve anchors using genetic algorithm
def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20, gen=1000):
""" Creates kmeans-evolved anchors from training dataset

Arguments:
path: path to dataset *.yaml
n: number of anchors
img_size: (min, max) image size used for multi-scale training (can be same values)
thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
gen: generations to evolve anchors using genetic algorithm

Return:
k: kmeans evolved anchors

Usage:
from utils.utils import *; _ = kmean_anchors()
"""

from utils.datasets import LoadImagesAndLabels

def print_results(k):
@@ -727,7 +739,9 @@ def kmean_anchors(path='./data/coco128.txt', n=9, img_size=(640, 640), thr=0.20,

# Get label wh
wh = []
dataset = LoadImagesAndLabels(path, augment=True, rect=True)
with open(path) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
nr = 1 if img_size[0] == img_size[1] else 3 # number augmentation repetitions
for s, l in zip(dataset.shapes, dataset.labels):
# wh.append(l[:, 3:5] * (s / s.max())) # image normalized to letterbox normalized wh
@@ -771,7 +785,6 @@ def kmean_anchors(path='./data/coco128.txt', n=9, img_size=(640, 640), thr=0.20,
f, k = fg, kg.copy()
print_results(k)
k = print_results(k)

return k



Loading…
Отказ
Запис