* DetectMultiBackend() `--half` handling * CI fixes * rename .half to .fp16 to avoid conflict * warmup fix * val update * engine update * engine updatemodifyDataloader
@@ -89,19 +89,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
# Load model | |||
device = select_device(device) | |||
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data) | |||
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine | |||
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) | |||
stride, names, pt = model.stride, model.names, model.pt | |||
imgsz = check_img_size(imgsz, s=stride) # check image size | |||
# Half | |||
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA | |||
if pt or jit: | |||
model.model.half() if half else model.model.float() | |||
elif engine and model.trt_fp16_input != half: | |||
LOGGER.info('model ' + ( | |||
'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.') | |||
half = model.trt_fp16_input | |||
# Dataloader | |||
if webcam: | |||
view_img = check_imshow() | |||
@@ -114,12 +105,12 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
vid_path, vid_writer = [None] * bs, [None] * bs | |||
# Run inference | |||
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz), half=half) # warmup | |||
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup | |||
dt, seen = [0.0, 0.0, 0.0], 0 | |||
for path, im, im0s, vid_cap, s in dataset: | |||
t1 = time_sync() | |||
im = torch.from_numpy(im).to(device) | |||
im = im.half() if half else im.float() # uint8 to fp16/32 | |||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 | |||
im /= 255 # 0 - 255 to 0.0 - 1.0 | |||
if len(im.shape) == 3: | |||
im = im[None] # expand for batch dim |
@@ -277,7 +277,7 @@ class Concat(nn.Module): | |||
class DetectMultiBackend(nn.Module): | |||
# YOLOv5 MultiBackend class for python inference on various backends | |||
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): | |||
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False): | |||
# Usage: | |||
# PyTorch: weights = *.pt | |||
# TorchScript: *.torchscript | |||
@@ -297,6 +297,7 @@ class DetectMultiBackend(nn.Module): | |||
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend | |||
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults | |||
w = attempt_download(w) # download if not local | |||
fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 | |||
if data: # data.yaml path (optional) | |||
with open(data, errors='ignore') as f: | |||
names = yaml.safe_load(f)['names'] # class names | |||
@@ -305,11 +306,13 @@ class DetectMultiBackend(nn.Module): | |||
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device) | |||
stride = max(int(model.stride.max()), 32) # model stride | |||
names = model.module.names if hasattr(model, 'module') else model.names # get class names | |||
model.half() if fp16 else model.float() | |||
self.model = model # explicitly assign for to(), cpu(), cuda(), half() | |||
elif jit: # TorchScript | |||
LOGGER.info(f'Loading {w} for TorchScript inference...') | |||
extra_files = {'config.txt': ''} # model metadata | |||
model = torch.jit.load(w, _extra_files=extra_files) | |||
model.half() if fp16 else model.float() | |||
if extra_files['config.txt']: | |||
d = json.loads(extra_files['config.txt']) # extra_files dict | |||
stride, names = int(d['stride']), d['names'] | |||
@@ -338,11 +341,11 @@ class DetectMultiBackend(nn.Module): | |||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download | |||
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 | |||
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) | |||
trt_fp16_input = False | |||
logger = trt.Logger(trt.Logger.INFO) | |||
with open(w, 'rb') as f, trt.Runtime(logger) as runtime: | |||
model = runtime.deserialize_cuda_engine(f.read()) | |||
bindings = OrderedDict() | |||
fp16 = False # default updated below | |||
for index in range(model.num_bindings): | |||
name = model.get_binding_name(index) | |||
dtype = trt.nptype(model.get_binding_dtype(index)) | |||
@@ -350,7 +353,7 @@ class DetectMultiBackend(nn.Module): | |||
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device) | |||
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) | |||
if model.binding_is_input(index) and dtype == np.float16: | |||
trt_fp16_input = True | |||
fp16 = True | |||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) | |||
context = model.create_execution_context() | |||
batch_size = bindings['images'].shape[0] | |||
@@ -458,11 +461,11 @@ class DetectMultiBackend(nn.Module): | |||
y = torch.tensor(y) if isinstance(y, np.ndarray) else y | |||
return (y, []) if val else y | |||
def warmup(self, imgsz=(1, 3, 640, 640), half=False): | |||
def warmup(self, imgsz=(1, 3, 640, 640)): | |||
# Warmup model by running inference once | |||
if self.pt or self.jit or self.onnx or self.engine: # warmup types | |||
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models | |||
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image | |||
im = torch.zeros(*imgsz).to(self.device).type(torch.half if self.fp16 else torch.float) # input image | |||
self.forward(im) # warmup | |||
@staticmethod |
@@ -125,7 +125,6 @@ def run(data, | |||
training = model is not None | |||
if training: # called by train.py | |||
device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model | |||
half &= device.type != 'cpu' # half precision only supported on CUDA | |||
model.half() if half else model.float() | |||
else: # called directly | |||
@@ -136,23 +135,17 @@ def run(data, | |||
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir | |||
# Load model | |||
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data) | |||
stride, pt, jit, onnx, engine = model.stride, model.pt, model.jit, model.onnx, model.engine | |||
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) | |||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine | |||
imgsz = check_img_size(imgsz, s=stride) # check image size | |||
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA | |||
if pt or jit: | |||
model.model.half() if half else model.model.float() | |||
elif engine: | |||
half = model.fp16 # FP16 supported on limited backends with CUDA | |||
if engine: | |||
batch_size = model.batch_size | |||
if model.trt_fp16_input != half: | |||
LOGGER.info('model ' + ( | |||
'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.') | |||
half = model.trt_fp16_input | |||
else: | |||
half = False | |||
batch_size = 1 # export.py models default to batch-size 1 | |||
device = torch.device('cpu') | |||
LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends') | |||
device = model.device | |||
if not pt or jit: | |||
batch_size = 1 # export.py models default to batch-size 1 | |||
LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') | |||
# Data | |||
data = check_dataset(data) # check | |||
@@ -166,7 +159,7 @@ def run(data, | |||
# Dataloader | |||
if not training: | |||
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz), half=half) # warmup | |||
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup | |||
pad = 0.0 if task in ('speed', 'benchmark') else 0.5 | |||
rect = False if task == 'benchmark' else pt # square inference for benchmarks | |||
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images |