|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- import torch
- import sys,os
- sys.path.extend(['segutils'])
- from core.models.bisenet import BiSeNet
- from torchvision import transforms
- import cv2,glob
- import numpy as np
- from core.models.dinknet import DinkNet34
- import matplotlib.pyplot as plt
- import time
class SegModel(object):
    """Single-image inference wrapper around a BiSeNet segmentation network.

    The network is built with ``nclass`` output channels, its weights are read
    from ``checkpoint['model']`` in the file ``weights``, and it is moved to
    ``device``. Each input image is resized to ``modelsize`` x ``modelsize``
    for inference and the per-pixel argmax label map is resized back to the
    input resolution.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        """Build the network and load its weights.

        Args:
            nclass: number of output classes passed to ``BiSeNet``.
            weights: path to a checkpoint whose ``'model'`` entry is a state dict.
            modelsize: side length of the square network input.
            device: torch device string the model (and inputs) are moved to.

        Raises:
            ValueError: if ``weights`` is None.
        """
        if weights is None:
            raise ValueError('SegModel requires a checkpoint path in `weights`')
        self.device = device
        self.modelsize = modelsize
        self.model = BiSeNet(nclass)
        # map_location lets a checkpoint that was saved from a GPU be loaded
        # onto whatever device was requested here (e.g. device='cpu').
        checkpoint = torch.load(weights, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model = self.model.to(self.device)
        # Per-channel normalization statistics on the [0, 1] scale; applied in
        # the channel order of the incoming image (see preprocess_image).
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        """Segment one image.

        Args:
            image: H x W x 3 array (e.g. as returned by ``cv2.imread``).

        Returns:
            (pred, outstr): ``pred`` is an H x W uint8 array of per-pixel class
            indices at the input resolution; ``outstr`` is a human-readable
            breakdown of the time spent in each stage, in milliseconds.
        """
        time0 = time.time()
        imageH, imageW = image.shape[:2]
        batch = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        batch = batch.to(self.device)
        with torch.no_grad():
            output = self.model(batch)
            # CUDA kernels run asynchronously; wait for them so the "infer"
            # figure measures the forward pass rather than just its launch.
            if output.is_cuda:
                torch.cuda.synchronize(output.device)
        time2 = time.time()
        # Class index per pixel for the first (only) image in the batch.
        pred = np.argmax(output.cpu().numpy(), axis=1)[0]
        time3 = time.time()
        # Nearest-neighbour keeps the result a valid label map: bilinear
        # interpolation (cv2's default) would produce in-between class ids at
        # region boundaries whenever there are more than two classes.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1),
            self.get_ms(time3, time2), self.get_ms(time4, time3),
            self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        """Return the elapsed time from t0 to t1 (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Turn an H x W x 3 image into a 1 x 3 x S x S float32 tensor, S = modelsize.

        Steps: resize to S x S, scale to [0, 1], subtract ``self.mean`` and
        divide by ``self.std`` per channel, reverse the channel order
        (``cv2.COLOR_RGB2BGR``), then move channels first and add a batch axis.
        """
        resized = cv2.resize(image, (self.modelsize, self.modelsize))
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        # NOTE(review): mean/std are applied to the channels as they arrive
        # (BGR when the image comes from cv2.imread) and only then swapped;
        # confirm this matches the normalization used at training time.
        normalized = (resized.astype(np.float32) / 255.0 - mean) / std
        normalized = cv2.cvtColor(normalized, cv2.COLOR_RGB2BGR)
        chw = np.transpose(normalized, (2, 0, 1))
        return torch.from_numpy(chw).float().unsqueeze(0)
-
def get_ms(t1, t0):
    """Convert the interval from t0 to t1 (seconds) into milliseconds."""
    elapsed_seconds = t1 - t0
    return 1000.0 * elapsed_seconds
-
-
def get_largest_contours(contours):
    """Return the index of the contour with the largest ``cv2.contourArea``.

    When several contours share the maximum area, the first one wins.
    ``contours`` must be non-empty; an empty sequence raises ValueError.
    """
    contour_areas = [cv2.contourArea(c) for c in contours]
    return max(range(len(contour_areas)), key=contour_areas.__getitem__)
-
if __name__ == '__main__':
    # Demo: segment every image in a folder, outline the largest predicted
    # region in yellow and write the annotated image to out_dir.
    nclass = 2
    weights = '../weights/BiSeNet/checkpoint.pth'

    segmodel = SegModel(nclass=nclass, weights=weights)

    # sorted() makes the processing order (and the printed index) reproducible;
    # raw glob order depends on the filesystem.
    image_urls = sorted(glob.glob('../../river_demo/images/slope/*'))
    out_dir = '../../river_demo/images/results/'
    os.makedirs(out_dir, exist_ok=True)

    for im, image_url in enumerate(image_urls):
        image_array0 = cv2.imread(image_url)
        if image_array0 is None:
            # cv2.imread returns None for unreadable or non-image files.
            print('skip unreadable image: %s' % image_url)
            continue
        H, W = image_array0.shape[:2]
        time_1 = time.time()
        pred, outstr = segmodel.eval(image_array0)

        binary0 = pred.copy()

        time0 = time.time()
        # findContours returns (image, contours, hierarchy) in OpenCV 3 and
        # (contours, hierarchy) in OpenCV 4; [-2] is the contour list in both.
        contours = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            # Reduce the mask to just the largest region.
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()

        time2 = time.time()
        if max_id >= 0:
            cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()

        out_url = os.path.join(out_dir, os.path.basename(image_url))
        if not cv2.imwrite(out_url, image_array0):
            print('failed to write: %s' % out_url)

        print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f' % (
            im, os.path.basename(image_url), H, W, get_ms(time0, time_1), outstr,
            get_ms(time1, time0), get_ms(time3, time2), get_ms(time3, time_1)))
-
-
-
-
-
-
-
-
-
|