# AIlib2/p2pNet.py
#
# 44 lines
# 1.5 KiB
# Python

import os
import torch
import time
import cv2
from PIL import Image
import torchvision.transforms as standard_transforms
from p2pnetUtils.p2pnet import build
from loguru import logger
class p2NnetModel(object):
    """Inference wrapper around a P2PNet model built by ``p2pnetUtils.p2pnet.build``.

    The network is built, moved to the configured device, loaded with weights
    and put in eval mode once at construction time; :meth:`eval` then runs a
    single OpenCV (BGR) image through it.
    """

    def __init__(self, weights=None, par=None):
        """Build the model and load its checkpoint.

        Args:
            weights: Path to a checkpoint file. It must exist, and the object
                ``torch.load`` returns from it must have a ``'model'`` entry
                holding the state dict.
            par: Configuration dict forwarded to ``build``. It must contain a
                ``'device'`` value accepted by ``torch.device``. Defaults to an
                empty dict (which then fails on the missing ``'device'`` key).

        Raises:
            AssertionError: if ``weights`` is None or the path does not exist.
            KeyError: if ``par`` has no ``'device'`` entry.
        """
        # None sentinel instead of a shared mutable ``{}`` default.
        if par is None:
            par = {}
        self.par = par
        self.device = torch.device(par['device'])
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``. AssertionError is kept so existing callers that
        # catch it keep working, and the path is now actually formatted
        # into the message.
        if weights is None or not os.path.exists(weights):
            raise AssertionError("%s not exists" % weights)
        self.model = build(par)
        self.model.to(self.device)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints.
        checkpoint = torch.load(weights, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()
        # ToTensor + ImageNet mean/std normalisation.
        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _aligned_size(width, height, multiple=128):
        """Round ``width`` and ``height`` down to a multiple of ``multiple``.

        Each side is clamped to at least ``multiple``, so inputs smaller than
        that are upscaled rather than collapsed to a zero-sized image.

        Returns:
            ``(new_width, new_height)`` tuple, as expected by ``PIL.Image.resize``.
        """
        return (max(multiple, width // multiple * multiple),
                max(multiple, height // multiple * multiple))

    def eval(self, image):
        """Run the model on one image and return its raw forward-pass output.

        Args:
            image: Image in OpenCV BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB``). Presumably an HxWx3 uint8 array; confirm
                against callers.

        Returns:
            Whatever ``self.model`` returns for a batch of one sample. Autograd
            is disabled, so the tensors carry no gradient history.
        """
        rgb = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Snap both sides to multiples of 128, as the original code did.
        # Presumably this is required by the network's downsampling; confirm.
        # LANCZOS is the filter formerly exposed as Image.ANTIALIAS, which
        # was removed in Pillow 10.
        rgb = rgb.resize(self._aligned_size(*rgb.size), Image.LANCZOS)
        # ToTensor already returns a tensor, so it only needs a batch axis.
        samples = self.transform(rgb).unsqueeze(0).to(self.device)
        # Pure inference: skip building the autograd graph to save memory/time.
        with torch.no_grad():
            preds = self.model(samples)
        return preds

    def get_ms(self, t1, t0):
        """Return the elapsed time ``t1 - t0`` (seconds) converted to milliseconds."""
        return (t1 - t0) * 1000.0