瀏覽代碼

Implement default class names (#1609)

5.0
Glenn Jocher GitHub 3 年之前
父節點
當前提交
d929bb656c
沒有發現已知的金鑰在資料庫的簽署中 GPG 金鑰 ID: 4AEE18F83AFDEB23
共有 1 個檔案被更改,包括 5 行新增4 行删除
  1. +5
    -4
      models/yolo.py

+ 5
- 4
models/yolo.py 查看文件

@@ -1,16 +1,16 @@
import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import math
import torch
import torch.nn as nn

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.autoanchor import check_anchor_order
@@ -82,6 +82,7 @@ class Model(nn.Module):
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

# Build strides, anchors

Loading…
取消
儲存