Browse Source

Implement default class names (#1609)

5.0
Glenn Jocher GitHub 3 years ago
parent
commit
d929bb656c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      models/yolo.py

+ 5
- 4
models/yolo.py View File

@@ -1,16 +1,16 @@
import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import math
import torch
import torch.nn as nn

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.autoanchor import check_anchor_order
@@ -82,6 +82,7 @@ class Model(nn.Module):
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

# Build strides, anchors

Loading…
Cancel
Save