2022-07-12 11:36:03 +08:00
import torch
import sys , os
2022-07-12 23:15:49 +08:00
sys . path . extend ( [ ' segutils ' ] )
2022-07-12 11:36:03 +08:00
from core . models . bisenet import BiSeNet
from torchvision import transforms
import cv2 , glob
import numpy as np
from core . models . dinknet import DinkNet34
import matplotlib . pyplot as plt
import time
class SegModel(object):
    """BiSeNet segmentation wrapper for single-image inference.

    Loads a checkpoint, moves the network to ``device`` and exposes
    :meth:`eval`, which returns a per-pixel class-index map at the input
    image's resolution together with a timing breakdown string.
    """

    def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
        """Build the network and load its weights.

        Args:
            nclass: number of output classes passed to ``BiSeNet``.
            weights: path to a checkpoint loadable by ``torch.load``; the
                state dict is read from its ``'model'`` key.
            modelsize: side length (pixels) of the square network input.
            device: torch device string the model is moved to and run on.
        """
        self.model = BiSeNet(nclass)
        # Alternative backbone kept for reference: DinkNet34(nclass)
        # Load onto the CPU first so a checkpoint saved on any GPU can be
        # restored on any host; the model is moved to `device` below.
        checkpoint = torch.load(weights, map_location='cpu')
        self.modelsize = modelsize
        self.model.load_state_dict(checkpoint['model'])
        self.device = device
        self.model = self.model.to(self.device)
        # Per-channel normalisation statistics on the 0-1 intensity scale.
        self.mean = (0.335, 0.358, 0.332)
        self.std = (0.141, 0.138, 0.143)

    def eval(self, image):
        """Segment one H x W x 3 image.

        Args:
            image: numpy image array, H x W x channels.

        Returns:
            (pred, outstr): ``pred`` is a uint8 (H, W) array of class indices
            at the input resolution; ``outstr`` is a timing breakdown in ms.
        """
        time0 = time.time()
        imageH, imageW = image.shape[:2]
        image = self.preprocess_image(image)
        time1 = time.time()
        self.model.eval()
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image)
            # CUDA kernels run asynchronously; synchronise so the "infer"
            # interval measures the forward pass rather than just its launch.
            if output.is_cuda:
                torch.cuda.synchronize(output.device)
        time2 = time.time()
        pred = output.data.cpu().numpy()
        # Class index with the highest score per pixel, first batch item.
        pred = np.argmax(pred, axis=1)[0]
        time3 = time.time()
        # Nearest-neighbour keeps every pixel a valid class index; the default
        # bilinear interpolation would blend labels along class boundaries.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f \n' % (
            self.get_ms(time1, time0), self.get_ms(time2, time1),
            self.get_ms(time3, time2), self.get_ms(time4, time3),
            self.get_ms(time4, time0))
        return pred, outstr

    def get_ms(self, t1, t0):
        """Return the interval ``t1 - t0`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def preprocess_image(self, image):
        """Convert a 3-channel image into a normalised (1, 3, S, S) float tensor.

        The image is resized to ``modelsize`` x ``modelsize``, scaled to 0-1,
        normalised with ``self.mean`` / ``self.std`` and has channels 0 and 2
        swapped before being transposed to channel-first layout.
        """
        image = cv2.resize(image, (self.modelsize, self.modelsize))
        image = image.astype(np.float32) / 255.0
        # NOTE(review): mean/std are applied in the input's channel order
        # (BGR when the image comes from cv2.imread), before the swap below;
        # confirm this matches the training pipeline.
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        image = (image - mean) / std
        # RGB->BGR and BGR->RGB are the same permutation (swap channels 0, 2).
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        return image.unsqueeze(0)
def get_ms(t1, t0):
    """Convert the elapsed time between ``t0`` and ``t1`` (seconds) to milliseconds."""
    elapsed_seconds = t1 - t0
    return elapsed_seconds * 1e3
def get_largest_contours(contours):
    """Return the index of the contour enclosing the largest area.

    Areas come from ``cv2.contourArea``. On ties the earliest index wins.
    An empty ``contours`` sequence raises ``ValueError`` (from ``max``).
    """
    contour_areas = [cv2.contourArea(contour) for contour in contours]
    return max(range(len(contour_areas)), key=contour_areas.__getitem__)
if __name__ == '__main__':
    # Demo: segment a validation image, keep only the largest predicted
    # region, outline it on the image, save the result and display it.
    nclass = 2
    weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
    segmodel = SegModel(nclass=nclass, weights=weights)
    image_urls = glob.glob('/home/thsw2/WJ/data/THexit/val/images/*')
    out_dir = '../runs/detect/exp2-seg'
    os.makedirs(out_dir, exist_ok=True)
    for image_url in image_urls[0:1]:
        # NOTE(review): debug override - always processes this file instead of
        # the globbed one; remove it to iterate over the directory.
        image_url = '/home/thsw2/WJ/data/THexit/val/images/54(199).JPG'
        image_array0 = cv2.imread(image_url)
        if image_array0 is None:
            # cv2.imread returns None for missing/unreadable files.
            print('failed to read image: %s' % image_url)
            continue
        # eval returns (label map, timing string), not just the label map.
        pred, outstr = segmodel.eval(image_array0)
        print(outstr, end='')
        binary0 = pred.copy()
        time0 = time.time()
        # [-2] selects the contour list under both OpenCV 3.x (3 return
        # values) and OpenCV 4.x (2 return values).
        contours = cv2.findContours(binary0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        max_id = -1
        if len(contours) > 0:
            max_id = get_largest_contours(contours)
            # Rebuild the mask so it contains only the largest contour's area.
            binary0[:, :] = 0
            cv2.fillPoly(binary0, [contours[max_id][:, 0, :]], 1)
        time1 = time.time()
        # time1..time2 brackets a (currently disabled) connected-components step.
        time2 = time.time()
        # With max_id == -1 drawContours draws every contour (none when empty).
        cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)
        time3 = time.time()
        out_url = '%s/%s' % (out_dir, os.path.basename(image_url))
        if not cv2.imwrite(out_url, image_array0):
            print('failed to write image: %s' % out_url)
        time4 = time.time()
        print('image:%s findcontours:%.1f ms , connect:%.1f ms ,draw:%.1f save:%.1f' % (
            os.path.basename(image_url), get_ms(time1, time0), get_ms(time2, time1),
            get_ms(time3, time2), get_ms(time4, time3)))
        plt.figure(0)
        plt.imshow(pred)
        plt.figure(1)
        # OpenCV images are BGR; matplotlib expects RGB.
        plt.imshow(cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB))
        plt.figure(2)
        plt.imshow(binary0)
        plt.show()