Browse Source

comment updates

5.0
Glenn Jocher 4 years ago
parent
commit
140d84cca1
2 changed files with 4 additions and 7 deletions
  1. +2
    -2
      train.py
  2. +2
    -5
      utils/utils.py

+ 2
- 2
train.py View File

@@ -152,13 +152,13 @@ def train(hyp):
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

# Distributed training
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
if device.type != 'cpu' and torch.cuda.device_count() > 1 and dist.is_available():
dist.init_process_group(backend='nccl', # distributed backend
init_method='tcp://127.0.0.1:9999', # init method
world_size=1, # number of nodes
rank=0) # node rank
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) # requires world_size > 1
model = torch.nn.parallel.DistributedDataParallel(model)
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,

+ 2
- 5
utils/utils.py View File

@@ -503,6 +503,7 @@ def build_targets(p, targets, model):
off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)

g = 0.5 # offset
style = 'rect4'
for i in range(det.nl):
anchors = det.anchors[i]
@@ -517,7 +518,6 @@ def build_targets(p, targets, model):
a, t = at[j], t.repeat(na, 1, 1)[j] # filter

# overlaps
g = 0.5 # offset
gxy = t[:, 2:4] # grid xy
z = torch.zeros_like(gxy)
if style == 'rect2':
@@ -878,10 +878,7 @@ def fitness(x):


def output_to_target(output, width, height):
"""
Convert a YOLO model output to target format
[batch_id, class_id, x, y, w, h, conf]
"""
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
if isinstance(output, torch.Tensor):
output = output.cpu().numpy()


Loading…
Cancel
Save