@@ -27,6 +27,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo | |||
""" | |||
from pathlib import Path | |||
from models.common import AutoShape | |||
from models.experimental import attempt_load | |||
from models.yolo import Model | |||
from utils.downloads import attempt_download | |||
@@ -55,7 +56,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo | |||
if len(ckpt['model'].names) == classes: | |||
model.names = ckpt['model'].names # set class names attribute | |||
if autoshape: | |||
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS | |||
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS | |||
return model.to(device) | |||
except Exception as e: |
@@ -23,7 +23,7 @@ from utils.datasets import exif_transpose, letterbox | |||
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible, | |||
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh) | |||
from utils.plots import Annotator, colors, save_one_box | |||
from utils.torch_utils import time_sync | |||
from utils.torch_utils import copy_attr, time_sync | |||
def autopad(k, p=None): # kernel, padding | |||
@@ -405,12 +405,10 @@ class AutoShape(nn.Module): | |||
def __init__(self, model): | |||
super().__init__() | |||
LOGGER.info('Adding AutoShape... ') | |||
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes | |||
self.model = model.eval() | |||
def autoshape(self): | |||
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape() | |||
return self | |||
def _apply(self, fn): | |||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers | |||
self = super()._apply(fn) |
@@ -22,8 +22,7 @@ from models.experimental import * | |||
from utils.autoanchor import check_anchor_order | |||
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args | |||
from utils.plots import feature_visualization | |||
from utils.torch_utils import (copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, | |||
time_sync) | |||
from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync | |||
try: | |||
import thop # for FLOPs computation | |||
@@ -226,12 +225,6 @@ class Model(nn.Module): | |||
self.info() | |||
return self | |||
def autoshape(self): # add AutoShape module | |||
LOGGER.info('Adding AutoShape... ') | |||
m = AutoShape(self) # wrap model | |||
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes | |||
return m | |||
def info(self, verbose=False, img_size=640): # print model information | |||
model_info(self, verbose, img_size) | |||