import os
import time

import cv2
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

from core.models.bisenet import BiSeNet
from core.models.dinknet import DinkNet34


class SegModel(object):
    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:3'):
        # self.args = args
        self.model = BiSeNet(nclass)
        # self.model = DinkNet34(nclass)
        checkpoint = torch.load(weights)
        self.modelsize = modelsize
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        '''self.composed_transforms = transforms.Compose([
            transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
            transforms.ToTensor()])
        '''
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image, outsize=None):
        # numpy image arrays are (height, width, channels)
        imageH, imageW, imageC = image.shape
        time0 = time.time()
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image, outsize=outsize)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]  # argmax over the class dimension -> per-pixel label map
        time3 = time.time()
        # cv2.resize expects (width, height); nearest-neighbour keeps the labels discrete
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH), interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        print('pre-process:%.1f ,infer:%.1f ,post-process:%.1f ,post-resize:%.1f ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1),
            self.get_ms(time3, time2), self.get_ms(time4, time3)))
        return pred

    def get_ms(self, t1, t0):
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        time0 = time.time()
        image = cv2.resize(image, (self.modelsize, self.modelsize))
        time1 = time.time()
        image = image.astype(np.float32)
        image /= 255.0
        time2 = time.time()
        # image -= self.mean
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        time3 = time.time()
        # image /= self.std
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        time4 = time.time()
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        time5 = time.time()
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)  # add batch dimension
        print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.1f ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2),
            self.get_ms(time4, time3), self.get_ms(time5, time4)))
        return image


def get_ms(t1, t0):
    return (t1 - t0) * 1000.0


if __name__ == '__main__':
    # os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
    '''
    image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
    nclass = 5
    weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
    '''
    image_url = 'temp_pics/DJI_0645.JPG'
    nclass = 2
    # weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
    # weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
    segmodel = SegModel(nclass=nclass, weights=weights, device='cuda:4')

    for i in range(10):
        image_array0 = cv2.imread(image_url)  # OpenCV loads images as BGR
        imageH, imageW, _ = image_array0.shape
        # print('###line84:', image_array0.shape)
        image_array = cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB)
        # image_in = segmodel.preprocess_image(image_array)
        pred = segmodel.eval(image_array, outsize=None)
        time0 = time.time()
        binary = pred.copy()
        time1 = time.time()
        contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        time2 = time.time()
        print(pred.shape, ' time copy:%.1f findcontour:%.1f ' % (get_ms(time1, time0), get_ms(time2, time1)))

    ## Measure how cv2.findContours runtime scales with the mask size
    binary0 = binary.copy()
    for ii, ss in enumerate([22, 256, 512, 1024, 2048]):
        time0 = time.time()
        resized = cv2.resize(binary0, (ss, ss), interpolation=cv2.INTER_NEAREST)
        time1 = time.time()
        contours, hierarchy = cv2.findContours(resized, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        time2 = time.time()
        print('size:%d resize:%.1f ,findtime:%.1f ' % (ss, get_ms(time1, time0), get_ms(time2, time1)))