|
|
@@ -46,7 +46,7 @@ class ocrModel(object): |
|
|
|
logger = trt.Logger(trt.Logger.ERROR)
|
|
|
|
with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
|
|
|
|
self.model=runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
|
|
|
|
self.context = self.model.create_execution_context()
|
|
|
|
#self.context = self.model.create_execution_context()
|
|
|
|
|
|
|
|
elif self.infer_type=='pth':
|
|
|
|
if par['mode']=='ch':
|
|
|
@@ -70,7 +70,7 @@ class ocrModel(object): |
|
|
|
self.model.eval()
|
|
|
|
preds = self.model(image)
|
|
|
|
else:
|
|
|
|
preds,trtstr=OcrTrtForward(self.model,[image],self.context)
|
|
|
|
preds,trtstr=OcrTrtForward(self.model,[image],False)
|
|
|
|
|
|
|
|
t2 = time.time()
|
|
|
|
preds_size = torch.IntTensor([preds.size(0)]*1)
|