|
- import torch
- from models_725.segWaterBuilding import SegModel
- from PIL import Image
- import numpy as np
- import cv2
- import os
- from cv2 import getTickCount, getTickFrequency
- import time
-
def filer_boundary(probs=None, contours=None, img_n=None, band=0.8):
    """Drop contours whose mean score inside the contour is too low.

    For every contour a filled mask is rasterised at the resolution of
    ``img_n``. The mask is multiplied element-wise with ``probs``, and the
    summed product is divided by the contour's polygon area from
    ``cv2.contourArea``. Contours whose resulting mean score is
    ``<= band`` are discarded.

    Args:
        probs: torch tensor of per-pixel scores. Assumed to broadcast against
            an ``(H, W)`` mask; TODO confirm the shape returned by the model.
        contours: sequence of contours as returned by ``cv2.findContours``.
            This is a list in some OpenCV versions and a tuple in others.
        img_n: image whose first two dimensions ``(H, W)`` define the mask
            size. Its pixel values are not used.
        band: score threshold; contours scoring ``<= band`` are removed.

    Returns:
        A new list holding the contours that passed, in their original order.
        The input sequence is not modified. (The previous version used
        ``list.remove``, which fails on tuples and can raise on numpy
        ``==`` comparisons.)
    """
    if not contours:
        return []

    height, width = img_n.shape[:2]
    # Build the mask on the device of ``probs`` instead of hard-coding
    # ``.cuda()``, because the element-wise product needs both tensors on
    # one device.
    device = probs.device if torch.is_tensor(probs) else None
    # One single-channel buffer, reused for every contour.
    mask = np.zeros((height, width), dtype=np.uint8)

    kept = []
    for idx, contour in enumerate(contours):
        area = cv2.contourArea(contour)
        if area <= 0:
            # Zero-area contours used to produce inf/nan ratios. Those never
            # compare <= band, so such contours were always kept; keep them
            # explicitly here.
            kept.append(contour)
            continue
        mask.fill(0)
        # thickness=-1 fills the contour interior with value 1.
        cv2.drawContours(mask, contours, idx, color=1, thickness=-1)
        mask_t = torch.as_tensor(mask, dtype=torch.float32, device=device)
        mean_score = torch.sum(mask_t.mul(probs)) / area
        if mean_score <= band:
            continue
        kept.append(contour)
    return kept
-
def predict_lunkuo(impth=None, segmodel=None, filter=False):
    """Segment an image and draw the outlines of the predicted regions on it.

    Args:
        impth: RGB image as a numpy array. Despite the name, it is passed to
            ``cv2.cvtColor``, so it must be image data, not a path.
        segmodel: model whose ``eval(image=...)`` returns ``(pred, probs)``.
            ``pred`` is a torch tensor of labels whose leading dimension is
            squeezed here (assumed batch size 1; TODO confirm).
        filter: when True, drop low-scoring contours with ``filer_boundary``
            and print how long that took, in milliseconds.

    Returns:
        A BGR copy of ``impth`` with all contours drawn in red, 8 px thick.
    """
    pred, probs = segmodel.eval(image=impth)
    # Build the 0/255 binary mask without writing into the model output.
    # The old code assigned 255 in place, which also broke for bool tensors:
    # they became 0/1 and were wiped out by the threshold.
    binary = (pred.squeeze(0) != 0).cpu().numpy().astype(np.uint8) * 255
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy). The second-to-last item is contours in both.
    contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]
    img_n = cv2.cvtColor(impth, cv2.COLOR_RGB2BGR)
    if filter:
        time_guolv = time.time()
        contours = filer_boundary(probs=probs, contours=contours, img_n=img_n)
        print('++++guolv', (time.time() - time_guolv) * 1000)
    return cv2.drawContours(img_n, contours, -1, (0, 0, 255), 8)
-
if __name__ == '__main__':
    # Run the segmentation + contour pipeline over every image file in a
    # directory. Print per-image processing time in ms, excluding loading.
    impth = '/home/data/lijiwen/wurenjiqifei/images/'
    segmodel = SegModel()
    # sorted() gives a deterministic processing order; listdir's is arbitrary.
    for fname in sorted(os.listdir(impth)):
        imgpath = os.path.join(impth, fname)
        if not os.path.isfile(imgpath):
            continue  # skip subdirectories; Image.open cannot read them
        # The context manager closes the file handle instead of leaving it
        # to the garbage collector.
        with Image.open(imgpath) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        time11 = time.time()
        predict_lunkuo(impth=img, segmodel=segmodel, filter=True)
        print('----all_process', (time.time() - time11) * 1000)
|