44 lines
1.5 KiB
Python
44 lines
1.5 KiB
Python
import os
|
|
import torch
|
|
import time
|
|
import cv2
|
|
from PIL import Image
|
|
import torchvision.transforms as standard_transforms
|
|
from p2pnetUtils.p2pnet import build
|
|
from loguru import logger
|
|
|
|
class p2NnetModel(object):
    """Inference wrapper around a P2PNet model built by ``p2pnetUtils.p2pnet.build``.

    Loads weights from a checkpoint file, moves the model to the configured
    device, and exposes :meth:`eval` to run a forward pass on a single image.
    """

    # Input sides are snapped down to a multiple of this value before
    # inference. NOTE(review): presumably tied to the network's downsampling
    # stride — confirm against the P2PNet backbone.
    SIZE_MULTIPLE = 128

    def __init__(self, weights=None, par=None):
        """Build the model and load its weights.

        Args:
            weights: Path to a checkpoint file loadable by ``torch.load``
                whose ``'model'`` entry is a state dict for the built model.
            par: Config dict passed to ``build``; must contain ``'device'``
                (any value accepted by ``torch.device``). ``None`` is treated
                as an empty dict.

        Raises:
            KeyError: if ``par`` has no ``'device'`` entry.
            FileNotFoundError: if ``weights`` is None or the path does not exist.
        """
        # None sentinel instead of a shared mutable ``{}`` default.
        par = {} if par is None else par
        self.par = par
        self.device = torch.device(par['device'])
        # Explicit check instead of ``assert`` (stripped under ``python -O``);
        # the previous message also never interpolated the path.
        if weights is None or not os.path.exists(weights):
            raise FileNotFoundError("%s not exists" % weights)

        self.model = build(par)
        self.model.to(self.device)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(weights, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Human-readable timing summary of the most recent ``eval`` call
        # (None until ``eval`` has run).
        self.last_timing = None

    @staticmethod
    def _aligned_size(width, height, multiple=128):
        """Round ``width`` and ``height`` down to a multiple of ``multiple``, never below it.

        Plain floor-rounding turns any side shorter than ``multiple`` into 0,
        which produces an empty image. Such sides are clamped up to ``multiple``.

        Returns:
            ``(new_width, new_height)`` tuple suitable for ``PIL.Image.resize``.
        """
        return (max(multiple, width // multiple * multiple),
                max(multiple, height // multiple * multiple))

    def _preprocess(self, image):
        """Turn a BGR image array into a normalized batch-of-one tensor on ``self.device``."""
        img_raw = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        width, height = img_raw.size
        new_size = self._aligned_size(width, height, self.SIZE_MULTIPLE)
        # LANCZOS is the filter ANTIALIAS used to alias; ANTIALIAS was removed in Pillow 10.
        img_raw = img_raw.resize(new_size, Image.LANCZOS)
        # The transform already yields a tensor; add the batch dimension directly
        # instead of re-wrapping it with the legacy ``torch.Tensor`` constructor.
        return self.transform(img_raw).unsqueeze(0).to(self.device)

    def eval(self, image):
        """Run the model on a single image and return its raw predictions.

        Args:
            image: Array accepted by ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``;
                presumably an HxWx3 uint8 BGR frame as produced by OpenCV —
                confirm with callers.

        Returns:
            Whatever ``self.model`` returns for the batch of one image.

        Side effects:
            Sets ``self.last_timing`` to a summary of total and pre-process time.
        """
        t0 = time.time()
        samples = self._preprocess(image)
        t1 = time.time()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            preds = self.model(samples)
        t2 = time.time()

        # Previously this summary was computed and discarded, and it reported the
        # total time as the pre-process time. On CUDA the inference time is
        # approximate because kernels run asynchronously.
        self.last_timing = 'p2pnet :%.1f (pre-process:%.1f, inference:%.1f) ' % (
            self.get_ms(t2, t0), self.get_ms(t1, t0), self.get_ms(t2, t1))
        return preds

    def get_ms(self, t1, t0):
        """Return the interval ``t1 - t0`` (in seconds) converted to milliseconds."""
        return (t1 - t0) * 1000.0