Browse Source

Remove `formats` variable to avoid `pd` conflict (#7993)

* Remove `formats` variable to avoid `pd` conflict

* Update export.py
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
945579699a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 9 deletions
  1. +5
    -5
      export.py
  2. +2
    -4
      utils/benchmarks.py

+ 5
- 5
export.py View File

): ):
t = time.time() t = time.time()
include = [x.lower() for x in include] # to lowercase include = [x.lower() for x in include] # to lowercase
formats = tuple(export_formats()['Argument'][1:]) # --include arguments
flags = [x in include for x in formats]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
flags = [x in include for x in fmts]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights


im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection


# Update model # Update model
if half and not (coreml or xml):
if half and not coreml and not xml:
im, model = im.half(), model.half() # to FP16 im, model = im.half(), model.half() # to FP16
model.train() if train else model.eval() # training mode = no Detect() layer grid construction model.train() if train else model.eval() # training mode = no Detect() layer grid construction
for k, m in model.named_modules(): for k, m in model.named_modules():
if any((saved_model, pb, tflite, edgetpu, tfjs)): if any((saved_model, pb, tflite, edgetpu, tfjs)):
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
model, f[5] = export_saved_model(model.cpu(), model, f[5] = export_saved_model(model.cpu(),
im, im,
file, file,

+ 2
- 4
utils/benchmarks.py View File

pt_only=False, # test PyTorch only pt_only=False, # test PyTorch only
): ):
y, t = [], time.time() y, t = [], time.time()
formats = export.export_formats()
device = select_device(device) device = select_device(device)
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
try: try:
assert i != 9, 'Edge TPU not supported' assert i != 9, 'Edge TPU not supported'
assert i != 10, 'TF.js not supported' assert i != 10, 'TF.js not supported'
pt_only=False, # test PyTorch only pt_only=False, # test PyTorch only
): ):
y, t = [], time.time() y, t = [], time.time()
formats = export.export_formats()
device = select_device(device) device = select_device(device)
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
try: try:
w = weights if f == '-' else \ w = weights if f == '-' else \
export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights

Loading…
Cancel
Save