# AIlib2/segutils/segMultiOutModel.py  (378 lines, 15 KiB, Python)

import torch
from core.models.bisenet import BiSeNet,BiSeNet_MultiOutput
from torchvision import transforms
import cv2,os,glob
import numpy as np
from core.models.dinknet import DinkNet34
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import time
class SegModel(object):
    """Thin wrapper around a segmentation network for single-image inference.

    Loads a checkpoint into an externally constructed model, normalizes
    input images, runs the network, and returns the prediction resized back
    to the input resolution.  When ``multiOutput`` is True the model is
    expected to return a list of outputs and a list of masks is returned.
    """

    def __init__(self, nclass=2, model=None, weights=None, modelsize=512,
                 device='cuda:3', multiOutput=False):
        # nclass is accepted for caller compatibility but not used here;
        # the class count is baked into ``model``.
        self.model = model
        # map_location='cpu' so a checkpoint saved on a different GPU than
        # ``device`` still loads; the model is moved to ``device`` below.
        checkpoint = torch.load(weights, map_location='cpu')
        self.modelsize = modelsize
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.multiOutput = multiOutput
        self.model = self.model.to(self.device)
        # Per-channel normalization constants used by preprocess_image.
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image, outsize=None, smooth_kernel=0):
        """Run inference on one HxWxC image.

        image: HxWx3 array (uint8 range 0-255 assumed — see preprocess_image).
        outsize: unused; output is always resized to the input image size.
        smooth_kernel: box-blur kernel applied to the raw prediction.
            Values <= 0 skip the blur entirely (cv2.blur raises on a
            zero-sized kernel, which made the former default crash).
        Returns a uint8 mask (or list of masks when multiOutput) of shape HxW.
        """
        imageH, imageW, imageC = image.shape
        time0 = time.time()
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            # The network's own smoothing is disabled (smooth_kernel=0);
            # smoothing, when requested, is applied on CPU below.
            output = self.model(image, test_flag=True, smooth_kernel=0)
        time2 = time.time()
        if self.multiOutput:
            pred = [outputx.data.cpu().numpy()[0] for outputx in output]
        else:
            pred = output.data.cpu().numpy()
            pred = pred[0]
        time3 = time.time()
        if self.multiOutput:
            if smooth_kernel > 0:
                pred = [cv2.blur(predx, (smooth_kernel, smooth_kernel)) for predx in pred]
            # Only the first two heads are returned, resized to input size.
            pred = [cv2.resize(predx.astype(np.uint8), (imageW, imageH)) for predx in pred[0:2]]
        else:
            if smooth_kernel > 0:
                pred = cv2.blur(pred, (smooth_kernel, smooth_kernel))
            pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH), interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        print('##line52:pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
        return pred

    def get_ms(self, t1, t0):
        """Elapsed time between two time.time() stamps, in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Resize, normalize (mean/std after /255) and convert an HxWx3
        image to a 1x3xSxS float tensor ready for the network."""
        time0 = time.time()
        image = cv2.resize(image, (self.modelsize, self.modelsize))
        time1 = time.time()
        image = image.astype(np.float32)
        image /= 255.0
        time2 = time.time()
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        time3 = time.time()
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        time4 = time.time()
        # HWC -> CHW, then add the batch dimension.
        image = np.transpose(image, (2, 0, 1))
        time5 = time.time()
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        print('###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
        return image
def get_ms(t1, t0):
    """Return the elapsed time between two time.time() stamps in milliseconds."""
    elapsed_seconds = t1 - t0
    return elapsed_seconds * 1000.0
def test():
    """Benchmark the single-output BiSeNet river model on one sample image.

    Repeats the full pipeline ten times so the per-stage timing prints
    reflect warmed-up performance.
    """
    image_url = 'temp_pics/DJI_0645.JPG'
    nclass = 2
    weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
    model = BiSeNet(nclass)
    segmodel = SegModel(model=model, nclass=nclass, weights=weights, device='cuda:4')
    for _ in range(10):
        frame = cv2.imread(image_url)
        imageH, imageW, _channels = frame.shape
        swapped = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        pred = segmodel.eval(swapped, outsize=None)
        t_start = time.time()
        binary = pred.copy()
        t_copied = time.time()
        contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        t_contoured = time.time()
        print(pred.shape, ' time copy:%.1f finccontour:%.1f ' % (get_ms(t_copied, t_start), get_ms(t_contoured, t_copied)))
# Per-task color tables, indexed by class id (entry 0 = background).
# Triples are stored RGB; drawing code reverses them ([::-1]) for OpenCV's
# BGR convention.
label_dic = {
    'landcover': [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0]],
    'deepRoad': [[0, 0, 0], [255, 0, 0]],
    'water': [[0, 0, 0], [255, 255, 255]],
    'water_building': [[0, 0, 0], [0, 0, 255], [255, 0, 0]],
    'floater': [[0, 0, 0], [0, 255, 0], [255, 255, 0], [255, 0, 255],
                [0, 128, 255], [255, 0, 0], [0, 255, 255]],
}
def index2color(label_mask, label_colours):
    """Map an integer class mask to a color image.

    Returns an HxWx3 uint8 array with channels stacked (b, g, r), ready for
    OpenCV.  Pixels whose value has no entry in ``label_colours`` keep their
    original mask value in every channel, matching the copy-then-overwrite
    approach.
    """
    chan_r = label_mask.copy()
    chan_g = label_mask.copy()
    chan_b = label_mask.copy()
    for class_id, colour in enumerate(label_colours):
        selected = label_mask == class_id
        chan_r[selected] = colour[0]
        chan_g[selected] = colour[1]
        chan_b[selected] = colour[2]
    stacked = np.stack((chan_b, chan_g, chan_r), axis=-1)
    return stacked.astype(np.uint8)
def get_largest_contours(contours):
    """Return the index of the contour with the largest area.

    Ties resolve to the first maximal contour; an empty sequence raises
    ValueError (from ``max``), exactly as the original area-list version did.
    """
    return max(range(len(contours)), key=lambda i: cv2.contourArea(contours[i]))
def result_merge_sep(image, mask_colors):
    """Draw class outlines on ``image``, one findContours pass per class id.

    mask_colors: list of dicts shaped
        {'mask': class_map, 'index': [class_ids...], 'color': [RGB triples...]}
    Each class id is binarized separately and its top-level contours are
    drawn in the matching color (reversed to BGR for OpenCV).  ``image`` is
    modified in place and also returned.
    """
    for mask_color in mask_colors:
        mask_map, indexes, colors = mask_color['mask'], mask_color['index'], mask_color['color']
        for index, color in zip(indexes, colors):
            mask_binaray = (mask_map == index).astype(np.uint8)
            contours, hierarchy = cv2.findContours(mask_binaray, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) > 0:
                # Keep only outer contours: hierarchy[0, i, 3] < 0 means the
                # contour has no parent.  dtype=object because contours are
                # ragged arrays.
                d = hierarchy[0, :, 3] < 0
                contours = np.array(contours, dtype=object)[d]
                cv2.drawContours(image, contours, -1, color[::-1], 3)
    return image
def result_merge(image, mask_colors):
    """Draw the outline of every non-zero region of each mask on ``image``.

    mask_colors: list of dicts shaped
        {'mask': class_map, 'index': [...], 'color': [RGB triples...]}
    Unlike result_merge_sep, all classes of a mask are merged (mask > 0)
    before contour extraction, and every contour is drawn in colors[0]
    (reversed to BGR).  ``image`` is modified in place and also returned.
    """
    for mask_color in mask_colors:
        mask_map, indexes, colors = mask_color['mask'], mask_color['index'], mask_color['color']
        mask_binary = (mask_map > 0).astype(np.uint8)
        contours, hierarchy = cv2.findContours(mask_binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) > 0:
            # dtype=object is required: contours are ragged arrays and newer
            # NumPy refuses an implicit object array (this matches
            # result_merge_sep / get_illegal_index).
            d = hierarchy[0, :, 3] < 0
            contours = np.array(contours, dtype=object)[d]
            cv2.drawContours(image, contours, -1, colors[0][::-1], 3)
            # NOTE(review): a per-class coloring pass (sampling mask_map at
            # each contour centroid) was started here but its drawing loop
            # was commented out; the dead centroid computation was removed.
    return image
def test_floater():
    """Smoke-test the multi-output floater model over a few validation images
    and save contour overlays to ``outdir``."""
    image_url = 'temp_pics/DJI_0645.JPG'
    nclass = [2, 7]
    outresult = True
    weights = 'runs/thFloater/BiSeNet_MultiOutput/train/experiment_4/checkpoint.pth'
    model = BiSeNet_MultiOutput(nclass)
    outdir = 'temp'
    image_dir = '/host/workspace/WJ/data/thFloater/val/images/'
    image_url_list = glob.glob('%s/*' % (image_dir))
    segmodel = SegModel(model=model, nclass=nclass, weights=weights, device='cuda:9', multiOutput=True)
    for i, image_url in enumerate(image_url_list[0:10]):
        image_array0 = cv2.imread(image_url)
        image_array0 = cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB)  # cv2 reads BGR by default
        imageH, imageW, _ = image_array0.shape
        # BUG FIX: the original passed the undefined name ``image_array``
        # (its assignment was commented out), raising NameError here.
        pred = segmodel.eval(image_array0, outsize=None)
        time0 = time.time()
        if isinstance(pred, list):
            binary = [predx.copy() for predx in pred]
        time1 = time.time()
        mask_colors = [{'mask': pred[0], 'index': range(1, 2), 'color': label_dic['water'][0:]},
                       {'mask': pred[1], 'index': [1, 2, 3, 4, 5, 6], 'color': label_dic['floater'][0:]}]
        result_draw = result_merge(image_array0, mask_colors)
        time2 = time.time()
        if outresult:
            basename = os.path.splitext(os.path.basename(image_url))[0]
            outname = os.path.join(outdir, basename + '_draw.png')
            cv2.imwrite(outname, result_draw[:, :, :])
        print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
def test_water_buildings():
    """Run the single-output water/building model over a folder of images
    and save contour overlays to ``outdir``."""
    from core.models.bisenet import BiSeNet
    nclass = 3
    outresult = True
    weights = 'runs/thWaterBuilding/BiSeNet/train/experiment_2/checkpoint.pth'
    model = BiSeNet(nclass)
    outdir = 'temp'
    image_dir = '/home/thsw/WJ/data/river_buildings/'
    image_paths = glob.glob('%s/*' % (image_dir))
    segmodel = SegModel(model=model, nclass=nclass, weights=weights, device='cuda:0', multiOutput=False)
    for idx, image_url in enumerate(image_paths[0:]):
        frame = cv2.imread(image_url)
        imageH, imageW, _ = frame.shape
        swapped = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        pred = segmodel.eval(swapped, outsize=None)
        t_start = time.time()
        # Single-output model: pred is an array, so this branch is a no-op
        # kept for parity with the multi-output test.
        if isinstance(pred, list):
            binary = [predx.copy() for predx in pred]
        t_copied = time.time()
        mask_colors = [{'mask': pred, 'index': range(1, 3), 'color': label_dic['water_building'][1:]}]
        result_draw = result_merge_sep(frame, mask_colors)
        t_drawn = time.time()
        if outresult:
            stem = os.path.splitext(os.path.basename(image_url))[0]
            outname = os.path.join(outdir, stem + '_draw.png')
            cv2.imwrite(outname, result_draw[:, :, :])
        print('##line294: time copy:%.1f finccontour:%.1f '%(get_ms(t_copied,t_start),get_ms(t_drawn,t_copied) ))
def get_illegal_index(contours, hierarchy, water_dilate, overlap_threshold):
    """Return indices of top-level contours whose overlap with the dilated
    water mask exceeds ``overlap_threshold``.

    contours/hierarchy: output of cv2.findContours (RETR_TREE).
    water_dilate: binary HxW mask the contours are tested against.
    overlap_threshold: fraction of the contour's own filled area that must
        intersect the mask for the contour to be flagged.

    NOTE(review): indices are positions in the parent-filtered list, while
    the caller draws from the unfiltered ``contours`` — these only agree
    when every contour is top-level; verify against caller.
    """
    out_index = []
    if len(contours) > 0:
        # Keep only contours with no parent in the hierarchy.
        d = hierarchy[0, :, 3] < 0
        contours = np.array(contours, dtype=object)[d]
        imageH, imageW = water_dilate.shape
        for ii, cont in enumerate(contours):
            build_area = np.zeros((imageH, imageW))
            cv2.fillPoly(build_area, [cont[:, 0, :]], 1)
            area1 = np.sum(build_area)
            area2 = np.sum(build_area * water_dilate)
            # Guard against degenerate (zero-area) polygons, which would
            # otherwise divide by zero.
            if area1 > 0 and (area2 / area1) > overlap_threshold:
                out_index.append(ii)
    return out_index
def test_water_building_seperately():
    """Detect buildings encroaching on a river using the two-headed model.

    Head 0 segments water, head 1 segments buildings.  The largest water
    region is dilated to form a buffer zone ("blue line"); buildings whose
    intersection with that zone exceeds 10% of their own area are flagged
    as illegal and outlined in red.  Overlays are saved to ``outdir``.
    """
    from core.models.dinknet import DinkNet34_MultiOutput
    #create_model('DinkNet34_MultiOutput',[2,5])
    image_url = 'temp_pics/DJI_0645.JPG'
    nclass = [2,2]
    outresult=True
    weights = 'runs/thWaterBuilding_seperate/BiSeNet_MultiOutput/train/experiment_0/checkpoint.pth'
    model = BiSeNet_MultiOutput(nclass)
    outdir='temp'
    image_dir = '/home/thsw/WJ/data/river_buildings/'
    #image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
    image_url_list=glob.glob('%s/*'%(image_dir))
    segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:1',multiOutput=True)
    print('###line307 image cnt:',len(image_url_list))
    for i,image_url in enumerate(image_url_list[0:1]) :
        # NOTE(review): the loop variable is immediately overwritten with a
        # hard-coded path, so only this one image is ever processed.
        image_url = '/home/thsw/WJ/data/river_buildings/DJI_20210904092044_0001_S_output896.jpg'
        image_array0 = cv2.imread(image_url)
        imageH,imageW,_ = image_array0.shape
        image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
        pred = segmodel.eval(image_array,outsize=None,smooth_kernel=20)
        ## Draw the water region: keep only the largest water contour.
        contours, hierarchy = cv2.findContours(pred[0],cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        max_id = get_largest_contours(contours);
        water = pred[0].copy(); water[:,:] = 0
        cv2.fillPoly(water, [contours[max_id][:,0,:]], 1)
        cv2.drawContours(image_array0,contours,max_id,(0,255,255),3)
        ## Dilate the water mask and draw the resulting buffer-zone outline.
        kernel = np.ones((100,100),np.uint8)
        water_dilate = cv2.dilate(water,kernel,iterations = 1)
        contours, hierarchy = cv2.findContours(water_dilate,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        #print('####line310:',contours)
        cv2.drawContours(image_array0,contours,-1,(255,0,0),3)
        ### Check each building against the buffer zone: if the overlap
        ### exceeds 10% of the building's own area, flag it as illegal.
        contours, hierarchy = cv2.findContours(pred[1],cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        outIndex=get_illegal_index(contours,hierarchy,water_dilate,0.1)
        for ii in outIndex:
            cv2.drawContours(image_array0,contours,ii,(0,0,255),3)
        plt.imshow(image_array0);plt.show()
        ##
        time0=time.time()
        time1=time.time()
        mask_colors=[ { 'mask':pred[0],'index':[1],'color':label_dic['water_building'][1:2]},
        { 'mask':pred[1],'index':[1],'color':label_dic['water_building'][2:3]}
        ]
        result_draw = result_merge_sep(image_array0,mask_colors)
        time2=time.time()
        if outresult:
            basename=os.path.splitext( os.path.basename(image_url))[0]
            outname=os.path.join(outdir,basename+'_draw.png')
            cv2.imwrite(outname,result_draw[:,:,:])
        print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
if __name__ == '__main__':
    # Other entry points kept for manual runs:
    # test(), test_floater(), test_water_buildings()
    test_water_building_seperately()