AIlib2/segutils/seg_detect.py

132 lines
4.8 KiB
Python

import torch
from core.models.bisenet import BiSeNet
from torchvision import transforms
import cv2,os
import numpy as np
from core.models.dinknet import DinkNet34
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import time
class SegModel(object):
    """Wrap a BiSeNet segmentation network for single-image inference.

    The network is built with ``nclass`` outputs, its weights are restored
    from ``checkpoint['model']`` of the file at ``weights``, and it is moved
    to ``device``. Each input image is resized to a square
    ``modelsize`` x ``modelsize`` network input, and the predicted per-pixel
    class map is resized back to the input's original height and width.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:3'):
        """Build the network and load its weights.

        Args:
            nclass: number of output classes passed to ``BiSeNet``.
            weights: path to a checkpoint file whose ``'model'`` entry is the
                network's ``state_dict``.
            modelsize: side length in pixels of the square network input.
            device: torch device string the model and inputs are placed on.

        Raises:
            ValueError: if ``weights`` is not given.
        """
        if weights is None:
            raise ValueError('SegModel requires a checkpoint path in `weights`')
        self.model = BiSeNet(nclass)
        # self.model = DinkNet34(nclass)
        # map_location loads the tensors straight onto `device`. Without it, a
        # checkpoint saved on a different GPU (or one that is absent here)
        # would first be restored onto its original device.
        checkpoint = torch.load(weights, map_location=device)
        self.modelsize = modelsize
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        # Per-channel normalisation statistics, applied after scaling to [0, 1].
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image, outsize=None):
        """Segment one image and return its per-pixel class map.

        Args:
            image: numpy image of shape (H, W, 3). The channel order must match
                what the checkpoint was trained with. The ``__main__`` demo
                feeds channel-swapped ``cv2.imread`` output (TODO confirm).
            outsize: forwarded unchanged to the network as its ``outsize``
                keyword argument.

        Returns:
            ``np.uint8`` array of shape (H, W), at the input image's
            resolution, holding the argmax class index of each pixel.
        """
        # numpy images are (rows, cols, channels) == (height, width, channels).
        imageH, imageW = image.shape[:2]
        time0 = time.time()
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image, outsize=outsize)
        time2 = time.time()
        # Axis 1 is the class-score axis, assuming NCHW output like the NCHW
        # input built by preprocess_image. Reducing on the device means only
        # the label map, not every class score, is copied to the host.
        pred = output.argmax(dim=1)[0].cpu().numpy()
        time3 = time.time()
        # cv2.resize takes dsize as (width, height). Nearest-neighbour keeps
        # every value a valid class index; bilinear would blend labels.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
        return pred

    def get_ms(self, t1, t0):
        """Return the interval from ``t0`` to ``t1`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Convert an (H, W, 3) numpy image into a normalised network input.

        The image is resized to ``modelsize`` x ``modelsize`` and scaled to
        [0, 1]. It is then normalised per channel with ``self.mean`` and
        ``self.std``, transposed from HWC to CHW, and given a leading batch
        axis.

        Returns:
            float32 torch tensor of shape (1, 3, modelsize, modelsize), on CPU.
        """
        time0 = time.time()
        image = cv2.resize(image, (self.modelsize, self.modelsize))
        time1 = time.time()
        image = image.astype(np.float32)
        image /= 255.0
        time2 = time.time()
        # Broadcasting over the trailing channel axis does the same float32
        # arithmetic as one in-place op per channel.
        image -= np.asarray(self.mean, dtype=np.float32)
        time3 = time.time()
        image /= np.asarray(self.std, dtype=np.float32)
        time4 = time.time()
        image = np.transpose(image, (2, 0, 1))
        time5 = time.time()
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
        return image
def get_ms(t1, t0):
    """Convert the interval between two ``time.time()`` stamps to milliseconds.

    ``t0`` is the earlier stamp and ``t1`` the later one. The result is
    positive when ``t1 > t0``.
    """
    elapsed_seconds = t1 - t0
    return elapsed_seconds * 1000.0
if __name__ == '__main__':
    # Demo / benchmark script. It segments one image ten times and times each
    # stage. It also times cv2.findContours on the predicted mask, resized to
    # several resolutions.
    # os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
    # Alternative (landcover) configuration:
    # image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
    # nclass = 5
    # weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
    image_url = 'temp_pics/DJI_0645.JPG'
    nclass = 2
    # weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
    # weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
    segmodel = SegModel(nclass=nclass, weights=weights, device='cuda:4')

    # The input is identical on every iteration, so read and convert it once.
    image_array0 = cv2.imread(image_url)
    if image_array0 is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError('cannot read image: %s' % image_url)
    # imread yields BGR. This swaps channels 0 and 2 (the swap is symmetric,
    # so RGB2BGR behaves the same as BGR2RGB).
    image_array = cv2.cvtColor(image_array0, cv2.COLOR_RGB2BGR)

    for _ in range(10):
        pred = segmodel.eval(image_array, outsize=None)
        time0 = time.time()
        binary = pred.copy()
        time1 = time.time()
        contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        time2 = time.time()
        print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))

        # Measure how findContours time scales with mask size. Resize the mask
        # to each side length and run findContours on that resized mask.
        binary0 = binary.copy()
        for ss in [22, 256, 512, 1024, 2048]:
            time0 = time.time()
            # Nearest-neighbour keeps the resized mask's values valid labels.
            image = cv2.resize(binary0, (ss, ss), interpolation=cv2.INTER_NEAREST)
            time1 = time.time()
            contours, hierarchy = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            time2 = time.time()
            print('size:%d resize:%.1f ,findtime:%.1f '%(ss, get_ms(time1,time0),get_ms(time2,time1)))