소스 검색

Implement default class names (#1609)

5.0
Glenn Jocher GitHub 3 년 전
부모
커밋
d929bb656c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1개의 변경된 파일5개의 추가작업 그리고 4개의 파일을 삭제
  1. +5
    -4
      models/yolo.py

+ 5
- 4
models/yolo.py 파일 보기

@@ -1,16 +1,16 @@
import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import math
import torch
import torch.nn as nn

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.autoanchor import check_anchor_order
@@ -82,6 +82,7 @@ class Model(nn.Module):
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

# Build strides, anchors

Loading…
취소
저장