Update yolo.py channel array (#2223)
This commit is contained in:
parent
7b833e37bf
commit
f8464b4f66
|
|
@ -2,7 +2,6 @@ import argparse
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -213,43 +212,27 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
||||||
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
|
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
|
||||||
C3]:
|
C3]:
|
||||||
c1, c2 = ch[f], args[0]
|
c1, c2 = ch[f], args[0]
|
||||||
|
if c2 != no: # if not output
|
||||||
# Normal
|
c2 = make_divisible(c2 * gw, 8)
|
||||||
# if i > 0 and args[0] != no: # channel expansion factor
|
|
||||||
# ex = 1.75 # exponential (default 2.0)
|
|
||||||
# e = math.log(c2 / ch[1]) / math.log(2)
|
|
||||||
# c2 = int(ch[1] * ex ** e)
|
|
||||||
# if m != Focus:
|
|
||||||
|
|
||||||
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
|
|
||||||
|
|
||||||
# Experimental
|
|
||||||
# if i > 0 and args[0] != no: # channel expansion factor
|
|
||||||
# ex = 1 + gw # exponential (default 2.0)
|
|
||||||
# ch1 = 32 # ch[1]
|
|
||||||
# e = math.log(c2 / ch1) / math.log(2) # level 1-n
|
|
||||||
# c2 = int(ch1 * ex ** e)
|
|
||||||
# if m != Focus:
|
|
||||||
# c2 = make_divisible(c2, 8) if c2 != no else c2
|
|
||||||
|
|
||||||
args = [c1, c2, *args[1:]]
|
args = [c1, c2, *args[1:]]
|
||||||
if m in [BottleneckCSP, C3]:
|
if m in [BottleneckCSP, C3]:
|
||||||
args.insert(2, n)
|
args.insert(2, n) # number of repeats
|
||||||
n = 1
|
n = 1
|
||||||
elif m is nn.BatchNorm2d:
|
elif m is nn.BatchNorm2d:
|
||||||
args = [ch[f]]
|
args = [ch[f]]
|
||||||
elif m is Concat:
|
elif m is Concat:
|
||||||
c2 = sum([ch[x if x < 0 else x + 1] for x in f])
|
c2 = sum([ch[x] for x in f])
|
||||||
elif m is Detect:
|
elif m is Detect:
|
||||||
args.append([ch[x + 1] for x in f])
|
args.append([ch[x] for x in f])
|
||||||
if isinstance(args[1], int): # number of anchors
|
if isinstance(args[1], int): # number of anchors
|
||||||
args[1] = [list(range(args[1] * 2))] * len(f)
|
args[1] = [list(range(args[1] * 2))] * len(f)
|
||||||
elif m is Contract:
|
elif m is Contract:
|
||||||
c2 = ch[f if f < 0 else f + 1] * args[0] ** 2
|
c2 = ch[f] * args[0] ** 2
|
||||||
elif m is Expand:
|
elif m is Expand:
|
||||||
c2 = ch[f if f < 0 else f + 1] // args[0] ** 2
|
c2 = ch[f] // args[0] ** 2
|
||||||
else:
|
else:
|
||||||
c2 = ch[f if f < 0 else f + 1]
|
c2 = ch[f]
|
||||||
|
|
||||||
m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
|
m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
|
||||||
t = str(m)[8:-2].replace('__main__.', '') # module type
|
t = str(m)[8:-2].replace('__main__.', '') # module type
|
||||||
|
|
@ -258,6 +241,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
||||||
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
|
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
|
||||||
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
||||||
layers.append(m_)
|
layers.append(m_)
|
||||||
|
if i == 0:
|
||||||
|
ch = []
|
||||||
ch.append(c2)
|
ch.append(c2)
|
||||||
return nn.Sequential(*layers), sorted(save)
|
return nn.Sequential(*layers), sorted(save)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue