|
|
@@ -475,9 +475,9 @@ def run( |
|
|
|
): |
|
|
|
t = time.time() |
|
|
|
include = [x.lower() for x in include] # to lowercase |
|
|
|
formats = tuple(export_formats()['Argument'][1:]) # --include arguments |
|
|
|
flags = [x in include for x in formats] |
|
|
|
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}' |
|
|
|
fmts = tuple(export_formats()['Argument'][1:]) # --include arguments |
|
|
|
flags = [x in include for x in fmts] |
|
|
|
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}' |
|
|
|
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans |
|
|
|
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights |
|
|
|
|
|
|
@@ -499,7 +499,7 @@ def run( |
|
|
|
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection |
|
|
|
|
|
|
|
# Update model |
|
|
|
if half and not (coreml or xml): |
|
|
|
if half and not coreml and not xml: |
|
|
|
im, model = im.half(), model.half() # to FP16 |
|
|
|
model.train() if train else model.eval() # training mode = no Detect() layer grid construction |
|
|
|
for k, m in model.named_modules(): |
|
|
@@ -531,7 +531,7 @@ def run( |
|
|
|
if any((saved_model, pb, tflite, edgetpu, tfjs)): |
|
|
|
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 |
|
|
|
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` |
|
|
|
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' |
|
|
|
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.' |
|
|
|
model, f[5] = export_saved_model(model.cpu(), |
|
|
|
im, |
|
|
|
file, |