# AIlib2/segutils/segmodel_BiseNet.py
# (viewer metadata from the original paste: 145 lines, 4.9 KiB, Python)

import torch
import sys,os
sys.path.extend(['segutils'])
from core.models.bisenet import BiSeNet
from torchvision import transforms
import cv2,glob
import numpy as np
from core.models.dinknet import DinkNet34
import matplotlib.pyplot as plt
import time
class SegModel(object):
    """Semantic-segmentation wrapper around a BiSeNet checkpoint.

    Loads the weights once, then `eval(image)` normalizes a frame, runs the
    network and returns a per-pixel class-index mask resized back to the
    original frame size, together with a timing summary string.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        """
        Args:
            nclass: number of segmentation classes.
            weights: path to a checkpoint file holding a 'model' state dict.
            modelsize: square side length the input is resized to.
            device: torch device string the model runs on.
        """
        self.model = BiSeNet(nclass)
        # map_location keeps loading working even when the checkpoint was
        # saved on a different device than the one we run on (fixes a crash
        # when a CUDA-saved checkpoint is loaded for another device).
        checkpoint = torch.load(weights, map_location=device)
        self.modelsize = modelsize
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        # Dataset statistics used for per-channel standardization.
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        """Segment one image.

        Args:
            image: HxWxC uint8 array (as returned by cv2.imread).

        Returns:
            (pred, outstr): `pred` is an HxW uint8 class-index mask resized
            to the input size; `outstr` is a human-readable timing summary.
        """
        time0 = time.time()
        imageH, imageW, imageC = image.shape
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        # argmax over the class dimension -> per-pixel class indices.
        pred = np.argmax(pred, axis=1)[0]
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1),
            self.get_ms(time3, time2), self.get_ms(time4, time3),
            self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        """Return the elapsed time t1 - t0 in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Resize, normalize and convert an image to a 1x3xHxW float tensor."""
        image = cv2.resize(image, (self.modelsize, self.modelsize))
        image = image.astype(np.float32)
        image /= 255.0
        # Per-channel standardization with the training-set statistics.
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        # NOTE(review): channels are swapped *after* normalization, so the
        # BGR frame from cv2.imread is standardized with the stats in their
        # stored order and then reordered. Presumably this matches how the
        # model was trained — confirm before changing the order.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)  # add batch dimension
        return image
def get_ms(t1, t0):
    """Milliseconds elapsed between timestamps t0 and t1 (seconds)."""
    delta_seconds = t1 - t0
    return delta_seconds * 1000.0
def get_largest_contours(contours):
    """Return the index of the contour with the largest area.

    On ties, the first contour reaching the maximum area wins (same as the
    original max/index pair). `contours` must be non-empty.
    """
    largest_idx = 0
    largest_area = cv2.contourArea(contours[0])
    for idx, contour in enumerate(contours[1:], start=1):
        area = cv2.contourArea(contour)
        if area > largest_area:
            largest_area = area
            largest_idx = idx
    return largest_idx
if __name__ == '__main__':
    # Demo driver: segment every image in the slope folder, keep only the
    # largest connected region, outline it and save the annotated frame.
    nclass = 2
    weights = '../weights/BiSeNet/checkpoint.pth'
    segmodel = SegModel(nclass=nclass, weights=weights)
    image_urls = glob.glob('../../river_demo/images/slope/*')
    out_dir = '../../river_demo/images/results/'
    os.makedirs(out_dir, exist_ok=True)
    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        H, W, C = image_array0.shape
        time_1 = time.time()
        pred, outstr = segmodel.eval(image_array0)
        binary0 = pred.copy()
        time0 = time.time()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        max_id = -1
        if len(contours) > 0:
            # Keep only the largest region as the final binary mask.
            max_id = get_largest_contours(contours)
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()
        time2 = time.time()
        if max_id >= 0:
            # Only draw when a contour was actually found; previously this
            # ran unconditionally with max_id == -1 (cv2's "draw all" index).
            cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        ret = cv2.imwrite(out_url, image_array0)
        time4 = time.time()
        # BUG FIX: 'total' previously used time3, silently excluding the
        # imwrite cost it claims to include.
        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time4, time_1)))