import os
import torch
import time
import cv2
from PIL import Image
import torchvision.transforms as standard_transforms
from p2pnetUtils.p2pnet import build
from loguru import logger


class p2NnetModel(object):
    """Inference wrapper around a network built by ``p2pnetUtils.p2pnet.build``.

    The checkpoint is loaded once at construction time; :meth:`eval` then runs
    the model on a single image and returns the raw model output.
    """

    # Input sides are rounded down to a multiple of this value before inference
    # (kept from the original code; presumably required by the network — TODO confirm).
    SIZE_MULTIPLE = 128

    def __init__(self, weights=None, par=None):
        """Build the network, move it to the configured device and load weights.

        Args:
            weights: Path to a checkpoint readable by ``torch.load`` whose
                state dict is stored under the key ``'model'``.
            par: Configuration dict. It is stored on ``self.par`` and passed to
                ``build``; ``par['device']`` is handed to ``torch.device``.

        Raises:
            KeyError: if ``par`` has no ``'device'`` entry.
            FileNotFoundError: if ``weights`` is None or the path does not exist.
        """
        # ``None`` instead of a shared mutable ``{}`` default.
        par = {} if par is None else par
        self.par = par
        self.device = torch.device(par['device'])
        # Explicit check: ``assert`` is stripped under ``python -O``, and the
        # original message never substituted the path into "%s".
        if weights is None or not os.path.exists(weights):
            raise FileNotFoundError("%s not exists" % weights)
        self.model = build(par)
        self.model.to(self.device)
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        checkpoint = torch.load(weights, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()
        # ToTensor + normalisation with the standard ImageNet mean/std.
        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]),
        ])

    def eval(self, image):
        """Run the model on one image and return its raw predictions.

        Args:
            image: Image array in OpenCV BGR channel order (it is converted
                with ``cv2.COLOR_BGR2RGB``); presumably an HxWx3 uint8 array
                -- TODO confirm against callers.

        Returns:
            Whatever ``self.model`` returns for a batch containing this single
            image. It is computed under ``torch.no_grad()``, so no autograd
            graph is attached.
        """
        t0 = time.time()
        img_raw = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        width, height = img_raw.size
        # Round each side down to a multiple of SIZE_MULTIPLE, but never below
        # one multiple: otherwise images smaller than that collapse to size 0.
        step = self.SIZE_MULTIPLE
        new_width = max(step, width // step * step)
        new_height = max(step, height // step * step)
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        # Image.Resampling exists from Pillow 9.1; fall back to Image otherwise.
        lanczos = getattr(Image, 'Resampling', Image).LANCZOS
        img_raw = img_raw.resize((new_width, new_height), lanczos)
        # The transform already yields a tensor; just add the batch dimension.
        samples = self.transform(img_raw).unsqueeze(0).to(self.device)
        t1 = time.time()
        # Inference only: avoid building (and holding memory for) autograd state.
        with torch.no_grad():
            preds = self.model(samples)
        t2 = time.time()
        # NOTE(review): on CUDA the forward call may return before the kernels
        # finish, so the inference time is approximate without synchronisation.
        logger.debug('p2pnet :{:.1f} (pre-process:{:.1f}, inference:{:.1f}) ',
                     self.get_ms(t2, t0), self.get_ms(t1, t0),
                     self.get_ms(t2, t1))
        return preds

    def get_ms(self, t1, t0):
        """Return the elapsed time from ``t0`` to ``t1`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0