import torch
import sys
import os

# Kept ahead of the `core.*` imports, as in the original statement order
# (presumably the `core` package is resolved through this path entry).
sys.path.extend(['segutils'])

from core.models.bisenet import BiSeNet
from torchvision import transforms
import cv2
import glob
import numpy as np
from core.models.dinknet import DinkNet34
import matplotlib.pyplot as plt
import time


class SegModel(object):
    """Thin inference wrapper around a BiSeNet segmentation network.

    Loads a checkpoint, normalises input images with fixed per-channel
    mean/std statistics and returns a per-pixel class-index mask resized
    back to the input resolution.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        """Build the network and load its weights.

        Args:
            nclass: Number of output classes passed to ``BiSeNet``.
            weights: Path to a checkpoint whose ``'model'`` entry holds the
                state dict. Required despite the ``None`` default.
            modelsize: Side length (pixels) of the square network input.
            device: Torch device string the model is moved to.

        Raises:
            ValueError: If ``weights`` is not provided.
        """
        if weights is None:
            raise ValueError('weights must be the path to a checkpoint file')

        self.modelsize = modelsize
        self.device = device

        # DinkNet34(nclass) is the alternative backbone imported by this file.
        self.model = BiSeNet(nclass)

        # Deserialize onto the CPU first so a checkpoint saved from a GPU
        # that is not present on this host still loads; the model is moved
        # to the requested device right after.
        checkpoint = torch.load(weights, map_location='cpu')
        self.model.load_state_dict(checkpoint['model'])
        self.model = self.model.to(self.device)

        # Per-channel normalisation statistics applied in preprocess_image.
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        """Segment one image.

        Args:
            image: ``H x W x C`` array (three channels are indexed during
                preprocessing).

        Returns:
            A ``(pred, outstr)`` tuple: ``pred`` is a ``uint8`` array of
            shape ``(H, W)`` holding the arg-max class index per pixel and
            ``outstr`` is a human-readable timing summary in milliseconds.
        """
        time0 = time.time()
        image_h, image_w, _ = image.shape
        batch = self.preprocess_image(image)
        time1 = time.time()

        self.model.eval()
        batch = batch.to(self.device)
        with torch.no_grad():
            output = self.model(batch)
        time2 = time.time()

        # Assumes the network returns a single (1, nclass, h, w) tensor --
        # TODO confirm against BiSeNet.forward.
        scores = output.detach().cpu().numpy()
        # Arg-max over the class axis, then drop the batch dimension.
        pred = np.argmax(scores, axis=1)[0]
        time3 = time.time()

        # Nearest-neighbour keeps every pixel a valid class index; the
        # default bilinear interpolation can blend label ids when nclass > 2.
        pred = cv2.resize(
            pred.astype(np.uint8),
            (image_w, image_h),
            interpolation=cv2.INTER_NEAREST,
        )
        time4 = time.time()

        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0),
            self.get_ms(time2, time1),
            self.get_ms(time3, time2),
            self.get_ms(time4, time3),
            self.get_ms(time4, time0),
        )
        return pred, outstr

    def get_ms(self, t1, t0):
        """Return the elapsed time between two ``time.time()`` stamps in ms."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Resize, normalise and convert an image into a ``1 x C x S x S`` tensor.

        Args:
            image: ``H x W x 3`` array.

        Returns:
            A float32 ``torch.Tensor`` of shape ``(1, 3, modelsize, modelsize)``.
        """
        resized = cv2.resize(image, (self.modelsize, self.modelsize))
        normed = resized.astype(np.float32)
        normed /= 255.0

        # Broadcast the per-channel statistics over the last axis; this is
        # the same arithmetic as normalising each channel separately.
        normed -= np.asarray(self.mean, dtype=np.float32)
        normed /= np.asarray(self.std, dtype=np.float32)

        # NOTE(review): statistics are applied by channel index *before* the
        # first and last channels are swapped here -- confirm this matches
        # the channel order used when the model was trained.
        normed = cv2.cvtColor(normed, cv2.COLOR_RGB2BGR)

        chw = np.transpose(normed, (2, 0, 1))
        tensor = torch.from_numpy(chw).float()
        return tensor.unsqueeze(0)


def get_ms(t1, t0):
    """Return the elapsed time between two ``time.time()`` stamps in ms."""
    return (t1 - t0) * 1000.0


def get_largest_contours(contours):
    """Return the index of the contour with the largest area.

    Ties resolve to the first such contour.

    Raises:
        ValueError: If ``contours`` is empty.
    """
    areas = [cv2.contourArea(contour) for contour in contours]
    return areas.index(max(areas))


def _run_demo():
    """Segment one validation image, outline its largest region and save it."""
    nclass = 2
    weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    segmodel = SegModel(nclass=nclass, weights=weights)

    image_urls = glob.glob('/home/thsw2/WJ/data/THexit/val/images/*')
    out_dir = '../runs/detect/exp2-seg'
    os.makedirs(out_dir, exist_ok=True)

    for image_url in image_urls[0:1]:
        image_array0 = cv2.imread(image_url)
        if image_array0 is None:
            # cv2.imread signals an unreadable file by returning None.
            print('skip unreadable image: %s' % image_url)
            continue

        # SegModel.eval returns (mask, timing string); only the mask is
        # needed here.
        pred, _ = segmodel.eval(image_array0)
        binary0 = pred.copy()

        time0 = time.time()
        # findContours returns 3 values on OpenCV 3.x and 2 on 4.x; the last
        # two are always (contours, hierarchy).
        contours, hierarchy = cv2.findContours(
            binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )[-2:]
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            # Keep only the largest region in the binary mask.
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()

        # Placeholder stamp for the disabled connected-components step so
        # the timing line keeps its original fields.
        time2 = time.time()

        if max_id >= 0:
            cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()

        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        cv2.imwrite(out_url, image_array0)
        time4 = time.time()

        print('image:%s findcontours:%.1f ms , connect:%.1f ms ,draw:%.1f save:%.1f' % (
            os.path.basename(image_url),
            get_ms(time1, time0),
            get_ms(time2, time1),
            get_ms(time3, time2),
            get_ms(time4, time3),
        ))

        plt.figure(0)
        plt.imshow(pred)
        plt.figure(1)
        plt.imshow(image_array0)
        plt.figure(2)
        plt.imshow(binary0)
        plt.show()


if __name__ == '__main__':
    _run_demo()