Browse Source

model definition update

5.0
Glenn Jocher 4 years ago
parent
commit
5bee686649
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      models/yolo.py

+ 4
- 2
models/yolo.py View File





class Model(nn.Module): class Model(nn.Module):
def __init__(self, model_yaml='yolov5s.yaml'): # cfg, number of classes, depth-width gains
def __init__(self, model_yaml='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
super(Model, self).__init__() super(Model, self).__init__()
with open(model_yaml) as f: with open(model_yaml) as f:
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
if nc:
self.md['nc'] = nc # override yaml value


# Define model # Define model
self.model, self.save, ch = parse_model(self.md, ch=[3]) # model, savelist, ch_out
self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
# print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))]) # print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])


# Build strides, anchors # Build strides, anchors

Loading…
Cancel
Save