|
- import torch
- from core.models.bisenet import BiSeNet
- from torchvision import transforms
- import cv2,os
- import numpy as np
- from core.models.dinknet import DinkNet34
- import matplotlib.pyplot as plt
-
- import matplotlib.pyplot as plt
- import time
class SegModel(object):
    """Inference wrapper around a ``BiSeNet`` segmentation network.

    The constructor loads a checkpoint and moves the network to ``device``.
    :meth:`eval` turns one image array into a per-pixel class-index mask with
    the same height and width as the input, printing per-stage timings in
    milliseconds on every call.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:3'):
        """Build the network and load its weights.

        Args:
            nclass: Number of classes passed to ``BiSeNet``.
            weights: Path of a checkpoint readable by ``torch.load`` that
                stores the network state dict under the ``'model'`` key.
            modelsize: Side length of the square image fed to the network.
            device: Torch device string the network and inputs are moved to.
        """
        self.model = BiSeNet(nclass)
        # Deserialize onto the CPU first so that a checkpoint saved on one GPU
        # index still loads on a machine where that index does not exist; the
        # network is moved to the requested device right after.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(weights, map_location='cpu')
        self.modelsize = modelsize
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        # Per-channel normalization statistics applied to the [0, 1]-scaled image.
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image, outsize=None):
        """Segment one image and return its class-index mask.

        Args:
            image: NumPy image array laid out as (height, width, channels).
            outsize: Forwarded unchanged to the network's forward call.

        Returns:
            A ``uint8`` array of shape (height, width) holding the arg-max
            class index of every pixel, resized back to the input resolution.
        """
        # NumPy image arrays are (rows, cols, channels): height comes first.
        imageH, imageW = image.shape[:2]
        time0 = time.time()
        tensor = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        tensor = tensor.to(self.device)
        with torch.no_grad():
            output = self.model(tensor, outsize=outsize)

        # NOTE(review): CUDA kernels are launched asynchronously, so without an
        # explicit synchronize the 'infer' figure may undercount and the rest
        # shows up in the next stage - confirm if exact numbers matter.
        time2 = time.time()
        pred = output.detach().cpu().numpy()
        # Arg-max over the class axis, then take the single batch element.
        pred = np.argmax(pred, axis=1)[0]
        time3 = time.time()
        # cv2.resize takes dsize as (width, height). Nearest-neighbour keeps
        # every output pixel a valid class index instead of blending labels.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f ' % (
            self.get_ms(time1, time0),
            self.get_ms(time2, time1),
            self.get_ms(time3, time2),
            self.get_ms(time4, time3),
        ))
        return pred

    def get_ms(self, t1, t0):
        """Return the interval from ``t0`` to ``t1`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Convert an image array into a normalized network input tensor.

        The image is resized to ``modelsize`` x ``modelsize``, scaled to
        [0, 1], normalized per channel with ``self.mean`` / ``self.std`` and
        rearranged to channel-first order with a leading batch axis of 1.
        Per-stage timings in milliseconds are printed.

        Args:
            image: NumPy image array laid out as (height, width, 3).

        Returns:
            A float32 ``torch.Tensor`` of shape (1, 3, modelsize, modelsize).
        """
        time0 = time.time()
        image = cv2.resize(image, (self.modelsize, self.modelsize))

        time1 = time.time()
        image = image.astype(np.float32)
        image /= 255.0

        time2 = time.time()
        # float32 statistics broadcast over the trailing channel axis; this is
        # numerically the same as subtracting/dividing channel by channel.
        image -= np.asarray(self.mean, dtype=np.float32)

        time3 = time.time()
        image /= np.asarray(self.std, dtype=np.float32)

        time4 = time.time()
        # HWC -> CHW, the layout the network is fed.
        image = np.transpose(image, (2, 0, 1))
        time5 = time.time()
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.1f ' % (
            self.get_ms(time1, time0),
            self.get_ms(time2, time1),
            self.get_ms(time3, time2),
            self.get_ms(time4, time3),
            self.get_ms(time5, time4),
        ))

        return image
-
-
-
def get_ms(t1, t0):
    """Return the elapsed time from ``t0`` to ``t1`` in milliseconds.

    Both arguments are timestamps in seconds, e.g. from ``time.time()``.
    """
    elapsed_seconds = t1 - t0
    return elapsed_seconds * 1000.0
-
def main():
    """Benchmark segmentation and contour extraction on one test image.

    Runs ``SegModel.eval`` ten times on the same image, timing the mask copy
    and ``cv2.findContours`` after each run, then measures how the
    ``findContours`` time changes when the last mask is resized to several
    square sizes.
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = str('4')

    # Alternative configuration (landcover, 5 classes):
    #   image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
    #   nclass = 5
    #   weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
    image_url = 'temp_pics/DJI_0645.JPG'
    nclass = 2
    # Other checkpoints that have been used with this script:
    #   '../weights/segmentation/BiSeNet/checkpoint.pth'
    #   'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
    weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'

    segmodel = SegModel(nclass=nclass, weights=weights, device='cuda:4')

    # Reading and colour conversion do not change between runs, so do them once.
    image_bgr = cv2.imread(image_url)
    if image_bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError('cannot read image: %s' % image_url)
    # cv2.imread yields BGR; swap to RGB (same channel swap the old
    # COLOR_RGB2BGR constant performed, but named for what it does).
    image_array = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    for _ in range(10):
        pred = segmodel.eval(image_array, outsize=None)
        time0 = time.time()
        binary = pred.copy()
        time1 = time.time()
        # [-2:] keeps this working on OpenCV 3.x, which returns
        # (image, contours, hierarchy), and on 4.x, which returns two values.
        contours, hierarchy = cv2.findContours(
            binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
        time2 = time.time()
        print(pred.shape, ' time copy:%.1f finccontour:%.1f ' % (
            get_ms(time1, time0), get_ms(time2, time1)))

    # Measure how findContours time depends on the mask size.
    binary0 = binary.copy()
    for ss in [22, 256, 512, 1024, 2048]:
        time0 = time.time()
        # Nearest-neighbour keeps the resized mask a valid label image.
        resized = cv2.resize(binary0, (ss, ss), interpolation=cv2.INTER_NEAREST)
        time1 = time.time()
        # Run on the resized mask so the timing actually reflects size `ss`.
        contours, hierarchy = cv2.findContours(
            resized, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
        time2 = time.time()
        print('size:%d resize:%.1f ,findtime:%.1f ' % (
            ss, get_ms(time1, time0), get_ms(time2, time1)))


if __name__ == '__main__':
    main()
-
|