145 lines
4.9 KiB
Python
145 lines
4.9 KiB
Python
import torch
|
|
import sys,os
|
|
sys.path.extend(['segutils'])
|
|
from core.models.bisenet import BiSeNet
|
|
from torchvision import transforms
|
|
import cv2,glob
|
|
import numpy as np
|
|
from core.models.dinknet import DinkNet34
|
|
import matplotlib.pyplot as plt
|
|
import time
|
|
class SegModel(object):
    """Single-image semantic segmentation with a BiSeNet network.

    The network is built with ``nclass`` output classes, its weights are read
    from the ``'model'`` entry of the checkpoint file ``weights``, and it is
    moved to ``device``. :meth:`eval` resizes an input image to
    ``modelsize`` x ``modelsize``, runs the network, and returns the per-pixel
    argmax label map resized back to the input resolution.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        """Build the network and load its weights.

        Args:
            nclass: number of output classes passed to ``BiSeNet``.
            weights: path to a checkpoint readable by ``torch.load`` whose
                ``'model'`` entry is the network's state dict. Required.
            modelsize: side length in pixels the input is resized to.
            device: torch device the network and its inputs are placed on.

        Raises:
            ValueError: if ``weights`` is not given.
        """
        if weights is None:
            raise ValueError('SegModel requires a checkpoint path in `weights`')
        self.model = BiSeNet(nclass)
        # Load tensors onto the CPU first so a checkpoint saved on one GPU can
        # be restored anywhere; the model is moved to `device` below.
        checkpoint = torch.load(weights, map_location='cpu')
        self.model.load_state_dict(checkpoint['model'])
        self.modelsize = modelsize
        self.device = device
        self.model = self.model.to(self.device)
        # Per-channel normalization constants for pixels scaled to [0, 1];
        # index i is applied to channel i of the image passed to eval().
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        """Segment one image and report per-stage timings.

        Args:
            image: ``H x W x C`` array (C >= 3); presumably a uint8 BGR image
                from ``cv2.imread`` — confirm for other callers.

        Returns:
            ``(pred, outstr)``: ``pred`` is an ``H x W`` uint8 array of class
            indices at the input resolution; ``outstr`` is a human-readable
            string with the stage timings in milliseconds.
        """
        time0 = time.perf_counter()
        image_h, image_w, _ = image.shape
        batch = self.preprocess_image(image)
        time1 = time.perf_counter()

        self.model.eval()
        batch = batch.to(self.device)
        with torch.no_grad():
            output = self.model(batch)
        # CUDA kernels launch asynchronously; wait for them so "infer" measures
        # the forward pass itself instead of leaking into "post-process".
        if output.is_cuda:
            torch.cuda.synchronize(output.device)
        time2 = time.perf_counter()

        # Highest-scoring class per pixel (axis 1) of the single batch item;
        # assumes output is (1, nclass, S, S) — TODO confirm for this BiSeNet.
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)[0]
        time3 = time.perf_counter()

        # Nearest-neighbour keeps labels discrete: linear interpolation can
        # invent an in-between class id at boundaries when nclass > 2.
        pred = cv2.resize(pred.astype(np.uint8), (image_w, image_h),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.perf_counter()

        outstr = ('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f \n '
                  % (self.get_ms(time1, time0), self.get_ms(time2, time1),
                     self.get_ms(time3, time2), self.get_ms(time4, time3),
                     self.get_ms(time4, time0)))
        return pred, outstr

    def get_ms(self, t1, t0):
        """Return the interval from ``t0`` to ``t1`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Resize ``image`` to the model size and return a ``1 x 3 x S x S`` float32 tensor.

        ``S`` is ``self.modelsize``; normalization and channel handling are
        done by ``_normalize_to_tensor``.
        """
        resized = cv2.resize(image, (self.modelsize, self.modelsize))
        return self._normalize_to_tensor(resized)

    def _normalize_to_tensor(self, image):
        """Normalize an ``H x W x C`` image into a ``1 x 3 x H x W`` float32 tensor.

        Pixels are scaled by 1/255, channel ``i`` has ``self.mean[i]``
        subtracted and is divided by ``self.std[i]``, then the channel order
        is reversed (the swap ``cv2.COLOR_RGB2BGR`` performs) and the array is
        moved to channels-first layout with a leading batch axis of size 1.
        """
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        # Only the first three channels are used (RGB2BGR drops a fourth one).
        pixels = np.asarray(image)[:, :, :3].astype(np.float32)
        normalized = (pixels / 255.0 - mean) / std
        # NOTE(review): normalization happens in the input's channel order,
        # *before* the swap. With BGR input from cv2.imread, mean[0]/std[0]
        # hit the blue channel; if these constants are RGB statistics they are
        # applied to swapped channels. Kept as-is because the trained weights
        # may depend on it — confirm against the training pipeline.
        chw = np.ascontiguousarray(normalized[:, :, ::-1].transpose(2, 0, 1))
        return torch.from_numpy(chw).unsqueeze(0)
def get_ms(t1, t0):
    """Convert the interval from timestamp ``t0`` to ``t1`` (seconds) into milliseconds."""
    seconds = t1 - t0
    return seconds * 1000.0
def get_largest_contours(contours):
    """Return the index of the contour with the largest ``cv2.contourArea``.

    On ties the earliest contour wins.

    Raises:
        ValueError: if ``contours`` is empty.
    """
    contour_areas = [cv2.contourArea(contour) for contour in contours]
    return max(range(len(contour_areas)), key=contour_areas.__getitem__)
if __name__ == '__main__':
    # Demo: segment every image under `image_dir`, outline the largest
    # predicted region on each image and save the result under `out_dir`.
    nclass = 2
    weights = '../weights/BiSeNet/checkpoint.pth'
    segmodel = SegModel(nclass=nclass, weights=weights)

    image_dir = '../../river_demo/images/slope'
    out_dir = '../../river_demo/images/results/'
    os.makedirs(out_dir, exist_ok=True)

    # glob returns files in arbitrary order; sort for reproducible runs.
    image_urls = sorted(glob.glob(os.path.join(image_dir, '*')))
    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        # imread signals unreadable / non-image files by returning None.
        if image_array0 is None:
            print('skip unreadable image: %s' % image_url)
            continue
        H, W = image_array0.shape[:2]

        time_start = time.perf_counter()
        pred, outstr = segmodel.eval(image_array0)
        binary0 = pred.copy()

        time_find = time.perf_counter()
        # [-2] selects the contour list under both the OpenCV 3.x
        # (image, contours, hierarchy) and 4.x (contours, hierarchy) returns.
        contours = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        # Stays -1 only when there are no contours; drawContours then draws
        # "all" contours of an empty list, i.e. nothing.
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            # Reduce the mask to the largest region only, filled with 1.
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)

        time_draw = time.perf_counter()
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time_drawn = time.perf_counter()

        out_url = os.path.join(out_dir, os.path.basename(image_url))
        if not cv2.imwrite(out_url, image_array0):
            print('failed to write %s' % out_url)

        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f' % (
            im, os.path.basename(image_url), H, W,
            get_ms(time_find, time_start), outstr,
            get_ms(time_draw, time_find), get_ms(time_drawn, time_draw),
            get_ms(time_drawn, time_start)))