AIlib2/segutils/segmodel.py

166 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import sys,os
sys.path.extend(['../AIlib2/segutils'])
from model_stages import BiSeNet_STDC
from torchvision import transforms
import cv2,glob
import numpy as np
from core.models.dinknet import DinkNet34
import matplotlib.pyplot as plt
import time
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
class SegModel(object):
    """Thin inference wrapper around a BiSeNet_STDC semantic-segmentation net.

    Loads pretrained weights onto ``device`` and exposes two evaluation entry
    points, ``eval`` and ``eval_zyy``, which differ only in their image
    pre-processing pipelines.  Both return a per-pixel class-id map resized
    back to the input resolution, plus a timing summary string.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        """Build the network and load ``weights`` onto ``device``.

        nclass    -- number of output classes.
        weights   -- path to a state-dict checkpoint (required in practice).
        modelsize -- currently unused; kept for backward compatibility.
        device    -- torch device string, e.g. 'cuda:0' or 'cpu'.
        """
        self.model = BiSeNet_STDC(backbone='STDCNet813', n_classes=nclass,
                                  use_boundary_2=False, use_boundary_4=False,
                                  use_boundary_8=True, use_boundary_16=False,
                                  use_conv_last=False)
        self.device = device
        self.model.load_state_dict(
            torch.load(weights, map_location=torch.device(self.device)))
        self.model = self.model.to(self.device)
        # ImageNet normalization constants (RGB channel order).
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def eval(self, image):
        """Segment a BGR uint8 image (H, W, 3).

        Returns (pred, outstr): ``pred`` is a uint8 (H, W) class-id map at
        the original resolution, ``outstr`` a per-stage timing summary.
        """
        time0 = time.time()
        imageH, imageW, _ = image.shape
        # Network was trained on RGB input; swap the B and R channels.
        image = self.RB_convert(image)
        img = self.preprocess_image(image)
        if self.device != 'cpu':
            imgs = img.to(self.device)
        else:
            imgs = img
        time1 = time.time()
        self.model.eval()
        with torch.no_grad():
            output = self.model(imgs)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        # argmax over the class channel -> per-pixel class ids for batch item 0.
        pred = np.argmax(pred, axis=1)[0]
        time3 = time.time()
        # Resize the low-resolution prediction back to the input size.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1),
            self.get_ms(time3, time2), self.get_ms(time4, time3),
            self.get_ms(time4, time0))
        return pred, outstr

    def eval_zyy(self, image):
        """Segment a BGR uint8 image using the torchvision pre-processing
        pipeline (kept to reproduce zyy's original results exactly).

        Returns (pred, outstr) with the same meaning as ``eval``.
        """
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
        time0 = time.time()
        imageH, imageW, _ = image.shape
        imgs = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        imgs = self.to_tensor(imgs)
        if self.device != 'cpu':
            imgs = imgs.to(self.device)
        imgs = torch.unsqueeze(imgs, dim=0)
        # Model input resolution is fixed at 360x640 (H, W).
        imgs = F.interpolate(imgs, [360, 640], mode='bilinear', align_corners=True)
        time1 = time.time()
        self.model.eval()
        with torch.no_grad():
            output = self.model(imgs)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        # argmax over the class channel -> per-pixel class ids for batch item 0.
        pred = np.argmax(pred, axis=1)[0]
        time3 = time.time()
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1),
            self.get_ms(time3, time2), self.get_ms(time4, time3),
            self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        """Elapsed time between timestamps ``t0`` and ``t1`` in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Resize to 640x360, normalize with ImageNet mean/std, and return a
        (1, 3, 360, 640) float32 tensor in CHW layout."""
        image = cv2.resize(image, (640, 360), interpolation=cv2.INTER_LINEAR)
        image = image.astype(np.float32)
        image /= 255.0
        image[:, :, 0] -= self.mean[0]
        image[:, :, 1] -= self.mean[1]
        image[:, :, 2] -= self.mean[2]
        image[:, :, 0] /= self.std[0]
        image[:, :, 1] /= self.std[1]
        image[:, :, 2] /= self.std[2]
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)  # add batch dimension
        return image

    def RB_convert(self, image):
        """Return a copy of ``image`` with channels 0 and 2 swapped
        (BGR <-> RGB)."""
        image_c = image.copy()
        image_c[:, :, 0] = image[:, :, 2]
        image_c[:, :, 2] = image[:, :, 0]
        return image_c
def get_ms(t1, t0):
    """Elapsed time between timestamps ``t0`` and ``t1`` in milliseconds."""
    elapsed = t1 - t0
    return elapsed * 1000.0
def get_largest_contours(contours):
    """Return the index of the contour with the largest area.

    Assumes ``contours`` is non-empty (the caller checks this).
    """
    contour_areas = list(map(cv2.contourArea, contours))
    biggest = max(contour_areas)
    return contour_areas.index(biggest)
if __name__ == '__main__':
    # Demo driver: segment every image under ``impth``, keep only the largest
    # connected region, outline it, and save the annotated image to ``outpth``.
    impth = '../../river_demo/images/slope/'
    outpth = 'results'
    # cv2.imwrite fails silently if the target directory does not exist.
    os.makedirs(outpth, exist_ok=True)
    weights = '../weights/STDC/model_maxmIOU75_1720_0.946_360640.pth'
    segmodel = SegModel(nclass=2, weights=weights)
    for i, fname in enumerate(os.listdir(impth)):
        imgpath = os.path.join(impth, fname)
        time0 = time.time()
        img = cv2.imread(imgpath)
        if img is None:
            # Not a readable image (e.g. a subdirectory or stray file) — skip
            # instead of crashing on ``image.shape`` inside eval().
            continue
        img = np.array(img)
        time1 = time.time()
        pred, outstr = segmodel.eval(image=img)
        time2 = time.time()
        binary0 = pred.copy()
        contours, hierarchy = cv2.findContours(binary0, cv2.RETR_TREE,
                                               cv2.CHAIN_APPROX_SIMPLE)
        time3 = time.time()
        max_id = -1
        if len(contours) > 0:
            # Keep only the largest detected region: refill the mask with it
            # and draw its outline on the original image.
            max_id = get_largest_contours(contours)
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
            cv2.drawContours(img, contours, max_id, (0, 255, 255), 3)
        time4 = time.time()
        cv2.imwrite(os.path.join(outpth, fname), img)
        time5 = time.time()
        # Bug fix: "total" previously printed time4-time1, silently dropping
        # the imwrite stage even though time5 was measured for it.
        print('image:%d ,infer:%.1f ms,findcontours:%.1f ms, draw:%.1f, total:%.1f' % (
            i, get_ms(time2, time1), get_ms(time3, time2),
            get_ms(time4, time3), get_ms(time5, time1)))