TensorRT 7 export fix (#6235)
This commit is contained in:
parent
33a67b4918
commit
6865d19a92
|
|
@ -174,7 +174,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
||||||
check_requirements(('tensorrt',))
|
check_requirements(('tensorrt',))
|
||||||
import tensorrt as trt
|
import tensorrt as trt
|
||||||
|
|
||||||
if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
||||||
grid = model.model[-1].anchor_grid
|
grid = model.model[-1].anchor_grid
|
||||||
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
||||||
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
|
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue