@@ -1,5 +1,6 @@
 import argparse
 import logging
+import math
 import os
 import random
 import time
@@ -7,7 +8,6 @@ from pathlib import Path
 from threading import Thread
 from warnings import warn

-import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
@@ -217,7 +217,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     model.nc = nc  # attach number of classes to model
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
-    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
     model.names = names

     # Start training
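
Note on the `* nc` rescaling: labels_to_class_weights() returns inverse-frequency weights normalized to sum to 1, so multiplying by nc puts them on a scale with a mean of roughly 1 per class. A minimal sketch of the effect, with hypothetical class counts:

    import torch

    counts = torch.tensor([100., 10., 1.])  # hypothetical instances per class
    nc = len(counts)
    w = 1.0 / counts   # inverse class frequency
    w /= w.sum()       # normalized: sums to 1 (as labels_to_class_weights does)
    print(w * nc)      # rescaled: sums to nc, mean weight 1.0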
@@ -238,7 +238,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         if opt.image_weights:
             # Generate indices
             if rank in [-1, 0]:
-                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
+                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                 iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                 dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
             # Broadcast if DDP
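
The `/ nc` here compensates for the `* nc` applied to model.class_weights above, leaving the image-weight computation numerically unchanged, while (1 - maps) ** 2 biases sampling toward classes with low mAP. A sketch of the resampling idea with made-up numbers:

    import random
    import numpy as np

    maps = np.array([0.9, 0.5, 0.1])                  # hypothetical per-class mAP
    cw = np.array([0.2, 0.3, 0.5]) * (1 - maps) ** 2  # weak classes get large weights
    # labels_to_image_weights() aggregates cw over the classes present in each
    # image; images containing weak classes are then drawn more often:
    iw = [0.1, 0.4, 1.5, 0.2]                         # hypothetical image weights
    indices = random.choices(range(len(iw)), weights=iw, k=len(iw))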
@@ -330,7 +330,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         if rank in [-1, 0]:
             # mAP
             if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
+                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
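
Adding 'class_weights' to the include list keeps the EMA copy's attached attributes in sync with the live model before it is handed to test.test(). A minimal sketch of what an attribute-sync helper like ema.update_attr() does (assumed shape, not the verbatim implementation):

    def update_attr(ema_model, model, include=()):
        # copy selected plain (non-parameter) attributes onto the EMA model
        for k in include:
            if hasattr(model, k):
                setattr(ema_model, k, getattr(model, k))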
@@ -1199,7 +1199,7 @@
         "\n",
         "m1 = lambda x: x * torch.sigmoid(x)\n",
         "m2 = torch.nn.SiLU()\n",
-        "profile(x=torch.randn(16, 3, 640, 640), [m1, m2], n=100)"
+        "profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
       ],
       "execution_count": null,
       "outputs": []
@@ -57,8 +57,8 @@ class FocalLoss(nn.Module):
             return loss.sum()
         else:  # 'none'
             return loss


 class QFocalLoss(nn.Module):
     # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
@@ -71,7 +71,7 @@ class QFocalLoss(nn.Module):
     def forward(self, pred, true):
         loss = self.loss_fcn(pred, true)

         pred_prob = torch.sigmoid(pred)  # prob from logits
         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
         modulating_factor = torch.abs(true - pred_prob) ** self.gamma
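
QFocalLoss implements the quality focal loss from Generalized Focal Loss (https://arxiv.org/abs/2006.04388): the modulating factor |true - sigmoid(pred)| ** gamma accepts soft quality targets in [0, 1], unlike the hard 0/1 targets assumed by standard focal loss. A brief usage sketch, with the class above in scope:

    import torch
    import torch.nn as nn

    criterion = QFocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
    pred = torch.randn(8, 80)  # class logits
    true = torch.rand(8, 80)   # soft (quality) targets in [0, 1]
    loss = criterion(pred, true)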
@@ -92,8 +92,8 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     h = model.hyp  # hyperparameters

     # Define criteria
-    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
-    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
+    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))  # weight=model.class_weights)
+    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

     # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
     cp, cn = smooth_BCE(eps=0.0)
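
torch.tensor([...], device=device) builds pos_weight directly on the target device, replacing the old pattern of constructing a CPU tensor with torch.Tensor and then moving the whole loss module via .to(device). It also avoids torch.Tensor's quirk of always returning float32, whereas torch.tensor infers dtype from its input. For example:

    import torch

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    a = torch.Tensor([1.0]).to(device)      # CPU allocation, then a device copy
    b = torch.tensor([1.0], device=device)  # allocated on the device directly
    assert torch.equal(a, b)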
@@ -119,7 +119,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
             # Regression
             pxy = ps[:, :2].sigmoid() * 2. - 0.5
             pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
-            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
+            pbox = torch.cat((pxy, pwh), 1)  # predicted box
             iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
             lbox += (1.0 - iou).mean()  # iou loss
@@ -81,8 +81,8 @@ def profile(x, ops, n=100, device=None):
     # m1 = lambda x: x * torch.sigmoid(x)
     # m2 = nn.SiLU()
     # profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
     device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     x = x.to(device)
     x.requires_grad = True
     print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
@@ -99,8 +99,11 @@ def profile(x, ops, n=100, device=None):
             t[0] = time_synchronized()
             y = m(x)
             t[1] = time_synchronized()
-            _ = y.sum().backward()
-            t[2] = time_synchronized()
+            try:
+                _ = y.sum().backward()
+                t[2] = time_synchronized()
+            except:  # no backward method
+                t[2] = float('nan')
             dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
             dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
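
Wrapping backward() in try/except lets profile() time ops whose output has no grad path instead of crashing; their backward time is reported as nan. A quick usage sketch, assuming profile() from this module is in scope:

    import torch

    m1 = torch.nn.SiLU()            # differentiable: forward and backward both timed
    m2 = lambda x: (x > 0).float()  # comparison breaks the grad path: backward -> nan
    profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=10)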