Browse Source

kmean_anchors() update

5.0
Glenn Jocher 4 years ago
parent
commit
8fa3724072
1 changed files with 21 additions and 8 deletions
  1. +21
    -8
      utils/utils.py

+ 21
- 8
utils/utils.py View File

import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
import yaml
from scipy.signal import butter, filtfilt from scipy.signal import butter, filtfilt
from tqdm import tqdm from tqdm import tqdm


shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images




def kmean_anchors(path='./data/coco128.txt', n=9, img_size=(640, 640), thr=0.20, gen=1000):
# Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
# n: number of anchors
# img_size: (min, max) image size used for multi-scale training (can be same values)
# thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
# gen: generations to evolve anchors using genetic algorithm
def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20, gen=1000):
""" Creates kmeans-evolved anchors from training dataset

Arguments:
path: path to dataset *.yaml
n: number of anchors
img_size: (min, max) image size used for multi-scale training (can be same values)
thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
gen: generations to evolve anchors using genetic algorithm

Return:
k: kmeans evolved anchors

Usage:
from utils.utils import *; _ = kmean_anchors()
"""

from utils.datasets import LoadImagesAndLabels from utils.datasets import LoadImagesAndLabels


def print_results(k): def print_results(k):


# Get label wh # Get label wh
wh = [] wh = []
dataset = LoadImagesAndLabels(path, augment=True, rect=True)
with open(path) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
nr = 1 if img_size[0] == img_size[1] else 3 # number augmentation repetitions nr = 1 if img_size[0] == img_size[1] else 3 # number augmentation repetitions
for s, l in zip(dataset.shapes, dataset.labels): for s, l in zip(dataset.shapes, dataset.labels):
# wh.append(l[:, 3:5] * (s / s.max())) # image normalized to letterbox normalized wh # wh.append(l[:, 3:5] * (s / s.max())) # image normalized to letterbox normalized wh
f, k = fg, kg.copy() f, k = fg, kg.copy()
print_results(k) print_results(k)
k = print_results(k) k = print_results(k)

return k return k





Loading…
Cancel
Save