Update loss criteria constructor (#1711)
This commit is contained in:
parent
799724108f
commit
8bc0027afc
8
train.py
8
train.py
|
|
@ -1,5 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
@ -7,7 +8,6 @@ from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
import math
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -217,7 +217,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||||
model.nc = nc # attach number of classes to model
|
model.nc = nc # attach number of classes to model
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
|
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
|
||||||
model.names = names
|
model.names = names
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
|
|
@ -238,7 +238,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||||
if opt.image_weights:
|
if opt.image_weights:
|
||||||
# Generate indices
|
# Generate indices
|
||||||
if rank in [-1, 0]:
|
if rank in [-1, 0]:
|
||||||
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
|
||||||
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
|
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
|
||||||
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
|
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
|
||||||
# Broadcast if DDP
|
# Broadcast if DDP
|
||||||
|
|
@ -330,7 +330,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||||
if rank in [-1, 0]:
|
if rank in [-1, 0]:
|
||||||
# mAP
|
# mAP
|
||||||
if ema:
|
if ema:
|
||||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
||||||
final_epoch = epoch + 1 == epochs
|
final_epoch = epoch + 1 == epochs
|
||||||
if not opt.notest or final_epoch: # Calculate mAP
|
if not opt.notest or final_epoch: # Calculate mAP
|
||||||
results, maps, times = test.test(opt.data,
|
results, maps, times = test.test(opt.data,
|
||||||
|
|
|
||||||
|
|
@ -1199,7 +1199,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"m1 = lambda x: x * torch.sigmoid(x)\n",
|
"m1 = lambda x: x * torch.sigmoid(x)\n",
|
||||||
"m2 = torch.nn.SiLU()\n",
|
"m2 = torch.nn.SiLU()\n",
|
||||||
"profile(x=torch.randn(16, 3, 640, 640), [m1, m2], n=100)"
|
"profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
|
||||||
],
|
],
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
|
|
|
||||||
|
|
@ -92,8 +92,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
h = model.hyp # hyperparameters
|
h = model.hyp # hyperparameters
|
||||||
|
|
||||||
# Define criteria
|
# Define criteria
|
||||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) # weight=model.class_weights)
|
||||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
|
||||||
|
|
||||||
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||||
cp, cn = smooth_BCE(eps=0.0)
|
cp, cn = smooth_BCE(eps=0.0)
|
||||||
|
|
@ -119,7 +119,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
# Regression
|
# Regression
|
||||||
pxy = ps[:, :2].sigmoid() * 2. - 0.5
|
pxy = ps[:, :2].sigmoid() * 2. - 0.5
|
||||||
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
||||||
pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
|
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
||||||
iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
|
||||||
lbox += (1.0 - iou).mean() # iou loss
|
lbox += (1.0 - iou).mean() # iou loss
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,8 +99,11 @@ def profile(x, ops, n=100, device=None):
|
||||||
t[0] = time_synchronized()
|
t[0] = time_synchronized()
|
||||||
y = m(x)
|
y = m(x)
|
||||||
t[1] = time_synchronized()
|
t[1] = time_synchronized()
|
||||||
_ = y.sum().backward()
|
try:
|
||||||
t[2] = time_synchronized()
|
_ = y.sum().backward()
|
||||||
|
t[2] = time_synchronized()
|
||||||
|
except: # no backward method
|
||||||
|
t[2] = float('nan')
|
||||||
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||||
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue